[FEATURE] Optional transforms in DataAwareInitializers

This commit is contained in:
Jensun Ravichandran 2021-06-04 22:14:45 +02:00
parent 8200e1d3d8
commit 827958a28a

View File

@ -84,33 +84,33 @@ class UniformInitializer(DimensionAwareInitializer):
class DataAwareInitializer(ComponentsInitializer): class DataAwareInitializer(ComponentsInitializer):
def __init__(self, data): def __init__(self, data, transform=torch.nn.Identity()):
super().__init__() super().__init__()
self.data = data self.data = data
self.transform = transform
def __del__(self):
del self.data
class SelectionInitializer(DataAwareInitializer): class SelectionInitializer(DataAwareInitializer):
def generate(self, length): def generate(self, length):
indices = torch.LongTensor(length).random_(0, len(self.data)) indices = torch.LongTensor(length).random_(0, len(self.data))
return self.data[indices] return self.transform(self.data[indices])
class MeanInitializer(DataAwareInitializer): class MeanInitializer(DataAwareInitializer):
def generate(self, length): def generate(self, length):
mean = torch.mean(self.data, dim=0) mean = torch.mean(self.data, dim=0)
repeat_dim = [length] + [1] * len(mean.shape) repeat_dim = [length] + [1] * len(mean.shape)
return mean.repeat(repeat_dim) return self.transform(mean.repeat(repeat_dim))
class ClassAwareInitializer(ComponentsInitializer): class ClassAwareInitializer(DataAwareInitializer):
def __init__(self, data, transform=torch.nn.Identity()): def __init__(self, data, transform=torch.nn.Identity()):
super().__init__()
data, targets = parse_data_arg(data) data, targets = parse_data_arg(data)
self.data = data super().__init__(data, transform)
self.targets = targets self.targets = targets
self.transform = transform
self.clabels = torch.unique(self.targets).int().tolist() self.clabels = torch.unique(self.targets).int().tolist()
self.num_classes = len(self.clabels) self.num_classes = len(self.clabels)