[FEATURE] Optional transforms in DataAwareInitializers

This commit is contained in:
Jensun Ravichandran 2021-06-04 22:14:45 +02:00
parent 8200e1d3d8
commit 827958a28a

View File

@ -84,33 +84,33 @@ class UniformInitializer(DimensionAwareInitializer):
class DataAwareInitializer(ComponentsInitializer):
def __init__(self, data):
def __init__(self, data, transform=torch.nn.Identity()):
super().__init__()
self.data = data
self.transform = transform
def __del__(self):
del self.data
class SelectionInitializer(DataAwareInitializer):
def generate(self, length):
indices = torch.LongTensor(length).random_(0, len(self.data))
return self.data[indices]
return self.transform(self.data[indices])
class MeanInitializer(DataAwareInitializer):
def generate(self, length):
mean = torch.mean(self.data, dim=0)
repeat_dim = [length] + [1] * len(mean.shape)
return mean.repeat(repeat_dim)
return self.transform(mean.repeat(repeat_dim))
class ClassAwareInitializer(ComponentsInitializer):
class ClassAwareInitializer(DataAwareInitializer):
def __init__(self, data, transform=torch.nn.Identity()):
super().__init__()
data, targets = parse_data_arg(data)
self.data = data
super().__init__(data, transform)
self.targets = targets
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels)