[FEATURE] Optional transforms in DataAwareInitializers
This commit is contained in:
parent
8200e1d3d8
commit
827958a28a
@ -84,33 +84,33 @@ class UniformInitializer(DimensionAwareInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class DataAwareInitializer(ComponentsInitializer):
|
class DataAwareInitializer(ComponentsInitializer):
|
||||||
def __init__(self, data):
|
def __init__(self, data, transform=torch.nn.Identity()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.data = data
|
self.data = data
|
||||||
|
self.transform = transform
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
del self.data
|
||||||
|
|
||||||
|
|
||||||
class SelectionInitializer(DataAwareInitializer):
|
class SelectionInitializer(DataAwareInitializer):
|
||||||
def generate(self, length):
|
def generate(self, length):
|
||||||
indices = torch.LongTensor(length).random_(0, len(self.data))
|
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||||
return self.data[indices]
|
return self.transform(self.data[indices])
|
||||||
|
|
||||||
|
|
||||||
class MeanInitializer(DataAwareInitializer):
|
class MeanInitializer(DataAwareInitializer):
|
||||||
def generate(self, length):
|
def generate(self, length):
|
||||||
mean = torch.mean(self.data, dim=0)
|
mean = torch.mean(self.data, dim=0)
|
||||||
repeat_dim = [length] + [1] * len(mean.shape)
|
repeat_dim = [length] + [1] * len(mean.shape)
|
||||||
return mean.repeat(repeat_dim)
|
return self.transform(mean.repeat(repeat_dim))
|
||||||
|
|
||||||
|
|
||||||
class ClassAwareInitializer(ComponentsInitializer):
|
class ClassAwareInitializer(DataAwareInitializer):
|
||||||
def __init__(self, data, transform=torch.nn.Identity()):
|
def __init__(self, data, transform=torch.nn.Identity()):
|
||||||
super().__init__()
|
|
||||||
data, targets = parse_data_arg(data)
|
data, targets = parse_data_arg(data)
|
||||||
self.data = data
|
super().__init__(data, transform)
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
|
|
||||||
self.transform = transform
|
|
||||||
|
|
||||||
self.clabels = torch.unique(self.targets).int().tolist()
|
self.clabels = torch.unique(self.targets).int().tolist()
|
||||||
self.num_classes = len(self.clabels)
|
self.num_classes = len(self.clabels)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user