diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index 1110088..816bef7 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -71,15 +71,16 @@ class ZerosInitializer(DimensionAwareInitializer): class UniformInitializer(DimensionAwareInitializer): - def __init__(self, dims, minimum=0.0, maximum=1.0): + def __init__(self, dims, minimum=0.0, maximum=1.0, scale=1.0): super().__init__(dims) - self.minimum = minimum self.maximum = maximum + self.scale = scale def generate(self, length): gen_dims = (length, ) + self.components_dims - return torch.ones(gen_dims).uniform_(self.minimum, self.maximum) + return torch.ones(gen_dims).uniform_(self.minimum, + self.maximum) * self.scale class DataAwareInitializer(ComponentsInitializer):