diff --git a/prototorch/components/components.py b/prototorch/components/components.py index 25686e1..aa7bb70 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -100,7 +100,7 @@ class Components(torch.nn.Module): return self._components def extra_repr(self): - return f"components.shape: {tuple(self._components.shape)}" + return f"(components): (shape: {tuple(self._components.shape)})" class LabeledComponents(Components): diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py index 9116f15..5229777 100644 --- a/prototorch/components/initializers.py +++ b/prototorch/components/initializers.py @@ -46,18 +46,22 @@ class ComponentsInitializer(object): class DimensionAwareInitializer(ComponentsInitializer): - def __init__(self, c_dims): + def __init__(self, dims): super().__init__() - if isinstance(c_dims, Iterable): - self.components_dims = tuple(c_dims) + if isinstance(dims, Iterable): + self.components_dims = tuple(dims) else: - self.components_dims = (c_dims, ) + self.components_dims = (dims, ) class OnesInitializer(DimensionAwareInitializer): + def __init__(self, dims, scale=1.0): + super().__init__(dims) + self.scale = scale + def generate(self, length): gen_dims = (length, ) + self.components_dims - return torch.ones(gen_dims) + return torch.ones(gen_dims) * self.scale class ZerosInitializer(DimensionAwareInitializer): @@ -67,8 +71,8 @@ class ZerosInitializer(DimensionAwareInitializer): class UniformInitializer(DimensionAwareInitializer): - def __init__(self, c_dims, min=0.0, max=1.0): - super().__init__(c_dims) + def __init__(self, dims, min=0.0, max=1.0): + super().__init__(dims) self.min = min self.max = max