Minor tweaks
This commit is contained in:
parent
2a7394b593
commit
d8a0b2dfcc
@ -100,7 +100,7 @@ class Components(torch.nn.Module):
|
||||
return self._components
|
||||
|
||||
def extra_repr(self):
|
||||
return f"components.shape: {tuple(self._components.shape)}"
|
||||
return f"(components): (shape: {tuple(self._components.shape)})"
|
||||
|
||||
|
||||
class LabeledComponents(Components):
|
||||
|
@ -46,18 +46,22 @@ class ComponentsInitializer(object):
|
||||
|
||||
|
||||
class DimensionAwareInitializer(ComponentsInitializer):
|
||||
def __init__(self, c_dims):
|
||||
def __init__(self, dims):
|
||||
super().__init__()
|
||||
if isinstance(c_dims, Iterable):
|
||||
self.components_dims = tuple(c_dims)
|
||||
if isinstance(dims, Iterable):
|
||||
self.components_dims = tuple(dims)
|
||||
else:
|
||||
self.components_dims = (c_dims, )
|
||||
self.components_dims = (dims, )
|
||||
|
||||
|
||||
class OnesInitializer(DimensionAwareInitializer):
|
||||
def __init__(self, dims, scale=1.0):
|
||||
super().__init__(dims)
|
||||
self.scale = scale
|
||||
|
||||
def generate(self, length):
|
||||
gen_dims = (length, ) + self.components_dims
|
||||
return torch.ones(gen_dims)
|
||||
return torch.ones(gen_dims) * self.scale
|
||||
|
||||
|
||||
class ZerosInitializer(DimensionAwareInitializer):
|
||||
@ -67,8 +71,8 @@ class ZerosInitializer(DimensionAwareInitializer):
|
||||
|
||||
|
||||
class UniformInitializer(DimensionAwareInitializer):
|
||||
def __init__(self, c_dims, min=0.0, max=1.0):
|
||||
super().__init__(c_dims)
|
||||
def __init__(self, dims, min=0.0, max=1.0):
|
||||
super().__init__(dims)
|
||||
|
||||
self.min = min
|
||||
self.max = max
|
||||
|
Loading…
Reference in New Issue
Block a user