Rename PositionAwareInitializer to DataAwareInitializer
Also, add the aliases `Zeros` and `Ones`.
This commit is contained in:
parent
0055e15bc1
commit
736d9a6349
@ -62,19 +62,19 @@ class UniformInitializer(DimensionAwareInitializer):
|
||||
return torch.ones(gen_dims).uniform_(self.min, self.max)
|
||||
|
||||
|
||||
class PositionAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, positions):
|
||||
class DataAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, data):
|
||||
super().__init__()
|
||||
self.data = positions
|
||||
self.data = data
|
||||
|
||||
|
||||
class SelectionInitializer(PositionAwareInitializer):
|
||||
class SelectionInitializer(DataAwareInitializer):
|
||||
def generate(self, length):
|
||||
indices = torch.LongTensor(length).random_(0, len(self.data))
|
||||
return self.data[indices]
|
||||
|
||||
|
||||
class MeanInitializer(PositionAwareInitializer):
|
||||
class MeanInitializer(DataAwareInitializer):
|
||||
def generate(self, length):
|
||||
mean = torch.mean(self.data, dim=0)
|
||||
repeat_dim = [length] + [1] * len(mean.shape)
|
||||
@ -205,3 +205,5 @@ class ZeroReasoningsInitializer(ReasoningsInitializer):
|
||||
SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer
|
||||
SMI = StratifiedMeanInitializer
|
||||
Random = RandomInitializer = UniformInitializer
|
||||
Zeros = ZerosInitializer
|
||||
Ones = OnesInitializer
|
||||
|
Loading…
Reference in New Issue
Block a user