Rename PositionAwareInitializer to DataAwareInitializer
Also, add the aliases `Zeros` and `Ones`.
This commit is contained in:
parent
0055e15bc1
commit
736d9a6349
@ -62,19 +62,19 @@ class UniformInitializer(DimensionAwareInitializer):
|
|||||||
return torch.ones(gen_dims).uniform_(self.min, self.max)
|
return torch.ones(gen_dims).uniform_(self.min, self.max)
|
||||||
|
|
||||||
|
|
||||||
class PositionAwareInitializer(ComponentsInitializer):
|
class DataAwareInitializer(ComponentsInitializer):
|
||||||
def __init__(self, positions):
|
def __init__(self, data):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.data = positions
|
self.data = data
|
||||||
|
|
||||||
|
|
||||||
class SelectionInitializer(PositionAwareInitializer):
|
class SelectionInitializer(DataAwareInitializer):
|
||||||
def generate(self, length):
|
def generate(self, length):
|
||||||
indices = torch.LongTensor(length).random_(0, len(self.data))
|
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||||
return self.data[indices]
|
return self.data[indices]
|
||||||
|
|
||||||
|
|
||||||
class MeanInitializer(PositionAwareInitializer):
|
class MeanInitializer(DataAwareInitializer):
|
||||||
def generate(self, length):
|
def generate(self, length):
|
||||||
mean = torch.mean(self.data, dim=0)
|
mean = torch.mean(self.data, dim=0)
|
||||||
repeat_dim = [length] + [1] * len(mean.shape)
|
repeat_dim = [length] + [1] * len(mean.shape)
|
||||||
@ -205,3 +205,5 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
|
|||||||
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
|
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
|
||||||
SMI = StratifiedMeanInitializer
|
SMI = StratifiedMeanInitializer
|
||||||
Random = RandomInitializer = UniformInitializer
|
Random = RandomInitializer = UniformInitializer
|
||||||
|
Zeros = ZerosInitializer
|
||||||
|
Ones = OnesInitializer
|
||||||
|
Loading…
Reference in New Issue
Block a user