Minor tweaks

This commit is contained in:
Jensun Ravichandran 2021-06-01 23:28:01 +02:00
parent 2a7394b593
commit d8a0b2dfcc
2 changed files with 12 additions and 8 deletions

View File

@ -100,7 +100,7 @@ class Components(torch.nn.Module):
return self._components return self._components
def extra_repr(self): def extra_repr(self):
return f"components.shape: {tuple(self._components.shape)}" return f"(components): (shape: {tuple(self._components.shape)})"
class LabeledComponents(Components): class LabeledComponents(Components):

View File

@ -46,18 +46,22 @@ class ComponentsInitializer(object):
class DimensionAwareInitializer(ComponentsInitializer): class DimensionAwareInitializer(ComponentsInitializer):
def __init__(self, c_dims): def __init__(self, dims):
super().__init__() super().__init__()
if isinstance(c_dims, Iterable): if isinstance(dims, Iterable):
self.components_dims = tuple(c_dims) self.components_dims = tuple(dims)
else: else:
self.components_dims = (c_dims, ) self.components_dims = (dims, )
class OnesInitializer(DimensionAwareInitializer): class OnesInitializer(DimensionAwareInitializer):
def __init__(self, dims, scale=1.0):
super().__init__(dims)
self.scale = scale
def generate(self, length): def generate(self, length):
gen_dims = (length, ) + self.components_dims gen_dims = (length, ) + self.components_dims
return torch.ones(gen_dims) return torch.ones(gen_dims) * self.scale
class ZerosInitializer(DimensionAwareInitializer): class ZerosInitializer(DimensionAwareInitializer):
@ -67,8 +71,8 @@ class ZerosInitializer(DimensionAwareInitializer):
class UniformInitializer(DimensionAwareInitializer): class UniformInitializer(DimensionAwareInitializer):
def __init__(self, c_dims, min=0.0, max=1.0): def __init__(self, dims, min=0.0, max=1.0):
super().__init__(c_dims) super().__init__(dims)
self.min = min self.min = min
self.max = max self.max = max