[FEATURE] Optional transforms in DataAwareInitializers
This commit is contained in:
		@@ -84,33 +84,33 @@ class UniformInitializer(DimensionAwareInitializer):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DataAwareInitializer(ComponentsInitializer):
 | 
			
		||||
    def __init__(self, data):
 | 
			
		||||
    def __init__(self, data, transform=torch.nn.Identity()):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.data = data
 | 
			
		||||
        self.transform = transform
 | 
			
		||||
 | 
			
		||||
    def __del__(self):
 | 
			
		||||
        del self.data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SelectionInitializer(DataAwareInitializer):
 | 
			
		||||
    def generate(self, length):
 | 
			
		||||
        indices = torch.LongTensor(length).random_(0, len(self.data))
 | 
			
		||||
        return self.data[indices]
 | 
			
		||||
        return self.transform(self.data[indices])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MeanInitializer(DataAwareInitializer):
 | 
			
		||||
    def generate(self, length):
 | 
			
		||||
        mean = torch.mean(self.data, dim=0)
 | 
			
		||||
        repeat_dim = [length] + [1] * len(mean.shape)
 | 
			
		||||
        return mean.repeat(repeat_dim)
 | 
			
		||||
        return self.transform(mean.repeat(repeat_dim))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ClassAwareInitializer(ComponentsInitializer):
 | 
			
		||||
class ClassAwareInitializer(DataAwareInitializer):
 | 
			
		||||
    def __init__(self, data, transform=torch.nn.Identity()):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        data, targets = parse_data_arg(data)
 | 
			
		||||
        self.data = data
 | 
			
		||||
        super().__init__(data, transform)
 | 
			
		||||
        self.targets = targets
 | 
			
		||||
 | 
			
		||||
        self.transform = transform
 | 
			
		||||
 | 
			
		||||
        self.clabels = torch.unique(self.targets).int().tolist()
 | 
			
		||||
        self.num_classes = len(self.clabels)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user