From 736d9a6349db29b69d5f17faf151270f20ccf7a6 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 18 May 2021 19:37:25 +0200 Subject: [PATCH] Rename PositionAwareInitializer to DataAwareInitializer Also, add the aliases `Zeros` and `Ones`. --- prototorch/components/initializers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index 13b9758..d8b6529 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -62,19 +62,19 @@ class UniformInitializer(DimensionAwareInitializer): return torch.ones(gen_dims).uniform_(self.min, self.max) -class PositionAwareInitializer(ComponentsInitializer): - def __init__(self, positions): +class DataAwareInitializer(ComponentsInitializer): + def __init__(self, data): super().__init__() - self.data = positions + self.data = data -class SelectionInitializer(PositionAwareInitializer): +class SelectionInitializer(DataAwareInitializer): def generate(self, length): indices = torch.LongTensor(length).random_(0, len(self.data)) return self.data[indices] -class MeanInitializer(PositionAwareInitializer): +class MeanInitializer(DataAwareInitializer): def generate(self, length): mean = torch.mean(self.data, dim=0) repeat_dim = [length] + [1] * len(mean.shape) @@ -205,3 +205,5 @@ class ZeroReasoningsInitializer(ReasoningsInitializer): SSI = StratifiedSampleInitializer = StratifiedSelectionInitializer SMI = StratifiedMeanInitializer Random = RandomInitializer = UniformInitializer +Zeros = ZerosInitializer +Ones = OnesInitializer