Minor tweaks
This commit is contained in:
parent
2a7394b593
commit
d8a0b2dfcc
@ -100,7 +100,7 @@ class Components(torch.nn.Module):
|
|||||||
return self._components
|
return self._components
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"components.shape: {tuple(self._components.shape)}"
|
return f"(components): (shape: {tuple(self._components.shape)})"
|
||||||
|
|
||||||
|
|
||||||
class LabeledComponents(Components):
|
class LabeledComponents(Components):
|
||||||
|
@ -46,18 +46,22 @@ class ComponentsInitializer(object):
|
|||||||
|
|
||||||
|
|
||||||
class DimensionAwareInitializer(ComponentsInitializer):
|
class DimensionAwareInitializer(ComponentsInitializer):
|
||||||
def __init__(self, c_dims):
|
def __init__(self, dims):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if isinstance(c_dims, Iterable):
|
if isinstance(dims, Iterable):
|
||||||
self.components_dims = tuple(c_dims)
|
self.components_dims = tuple(dims)
|
||||||
else:
|
else:
|
||||||
self.components_dims = (c_dims, )
|
self.components_dims = (dims, )
|
||||||
|
|
||||||
|
|
||||||
class OnesInitializer(DimensionAwareInitializer):
|
class OnesInitializer(DimensionAwareInitializer):
|
||||||
|
def __init__(self, dims, scale=1.0):
|
||||||
|
super().__init__(dims)
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
def generate(self, length):
|
def generate(self, length):
|
||||||
gen_dims = (length, ) + self.components_dims
|
gen_dims = (length, ) + self.components_dims
|
||||||
return torch.ones(gen_dims)
|
return torch.ones(gen_dims) * self.scale
|
||||||
|
|
||||||
|
|
||||||
class ZerosInitializer(DimensionAwareInitializer):
|
class ZerosInitializer(DimensionAwareInitializer):
|
||||||
@ -67,8 +71,8 @@ class ZerosInitializer(DimensionAwareInitializer):
|
|||||||
|
|
||||||
|
|
||||||
class UniformInitializer(DimensionAwareInitializer):
|
class UniformInitializer(DimensionAwareInitializer):
|
||||||
def __init__(self, c_dims, min=0.0, max=1.0):
|
def __init__(self, dims, min=0.0, max=1.0):
|
||||||
super().__init__(c_dims)
|
super().__init__(dims)
|
||||||
|
|
||||||
self.min = min
|
self.min = min
|
||||||
self.max = max
|
self.max = max
|
||||||
|
Loading…
Reference in New Issue
Block a user