[FEATURE] Optional transforms in DataAwareInitializers
This commit is contained in:
parent
8200e1d3d8
commit
827958a28a
@ -84,33 +84,33 @@ class UniformInitializer(DimensionAwareInitializer):
|
||||
|
||||
|
||||
class DataAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, data):
|
||||
def __init__(self, data, transform=torch.nn.Identity()):
|
||||
super().__init__()
|
||||
self.data = data
|
||||
self.transform = transform
|
||||
|
||||
def __del__(self):
|
||||
del self.data
|
||||
|
||||
|
||||
class SelectionInitializer(DataAwareInitializer):
|
||||
def generate(self, length):
|
||||
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||
return self.data[indices]
|
||||
return self.transform(self.data[indices])
|
||||
|
||||
|
||||
class MeanInitializer(DataAwareInitializer):
|
||||
def generate(self, length):
|
||||
mean = torch.mean(self.data, dim=0)
|
||||
repeat_dim = [length] + [1] * len(mean.shape)
|
||||
return mean.repeat(repeat_dim)
|
||||
return self.transform(mean.repeat(repeat_dim))
|
||||
|
||||
|
||||
class ClassAwareInitializer(ComponentsInitializer):
|
||||
class ClassAwareInitializer(DataAwareInitializer):
|
||||
def __init__(self, data, transform=torch.nn.Identity()):
|
||||
super().__init__()
|
||||
data, targets = parse_data_arg(data)
|
||||
self.data = data
|
||||
super().__init__(data, transform)
|
||||
self.targets = targets
|
||||
|
||||
self.transform = transform
|
||||
|
||||
self.clabels = torch.unique(self.targets).int().tolist()
|
||||
self.num_classes = len(self.clabels)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user