Rename PositionAwareInitializer to DataAwareInitializer

Also, add the aliases `Zeros` and `Ones`.
This commit is contained in:
Jensun Ravichandran 2021-05-18 19:37:25 +02:00
parent 0055e15bc1
commit 736d9a6349

View File

@ -62,19 +62,19 @@ class UniformInitializer(DimensionAwareInitializer):
return torch.ones(gen_dims).uniform_(self.min, self.max) return torch.ones(gen_dims).uniform_(self.min, self.max)
class PositionAwareInitializer(ComponentsInitializer): class DataAwareInitializer(ComponentsInitializer):
def __init__(self, positions): def __init__(self, data):
super().__init__() super().__init__()
self.data = positions self.data = data
class SelectionInitializer(PositionAwareInitializer): class SelectionInitializer(DataAwareInitializer):
def generate(self, length): def generate(self, length):
indices = torch.LongTensor(length).random_(0, len(self.data)) indices = torch.LongTensor(length).random_(0, len(self.data))
return self.data[indices] return self.data[indices]
class MeanInitializer(PositionAwareInitializer): class MeanInitializer(DataAwareInitializer):
def generate(self, length): def generate(self, length):
mean = torch.mean(self.data, dim=0) mean = torch.mean(self.data, dim=0)
repeat_dim = [length] + [1] * len(mean.shape) repeat_dim = [length] + [1] * len(mean.shape)
@ -205,3 +205,5 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
SMI = StratifiedMeanInitializer SMI = StratifiedMeanInitializer
Random = RandomInitializer = UniformInitializer Random = RandomInitializer = UniformInitializer
Zeros = ZerosInitializer
Ones = OnesInitializer