fix: correct typo
This commit is contained in:
parent
71a2e74eff
commit
0d10fc7e25
@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter
|
|||||||
|
|
||||||
from .initializers import (
|
from .initializers import (
|
||||||
AbstractLinearTransformInitializer,
|
AbstractLinearTransformInitializer,
|
||||||
EyeTransformInitializer,
|
EyeLinearTransformInitializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -15,7 +15,7 @@ class LinearTransform(torch.nn.Module):
|
|||||||
in_dim: int,
|
in_dim: int,
|
||||||
out_dim: int,
|
out_dim: int,
|
||||||
initializer:
|
initializer:
|
||||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.set_weights(in_dim, out_dim, initializer)
|
self.set_weights(in_dim, out_dim, initializer)
|
||||||
|
|
||||||
@ -31,7 +31,7 @@ class LinearTransform(torch.nn.Module):
|
|||||||
in_dim: int,
|
in_dim: int,
|
||||||
out_dim: int,
|
out_dim: int,
|
||||||
initializer:
|
initializer:
|
||||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
|
||||||
weights = initializer.generate(in_dim, out_dim)
|
weights = initializer.generate(in_dim, out_dim)
|
||||||
self._register_weights(weights)
|
self._register_weights(weights)
|
||||||
|
|
||||||
|
@ -245,20 +245,20 @@ def test_random_reasonings_init_channels_not_first():
|
|||||||
|
|
||||||
# Transform initializers
|
# Transform initializers
|
||||||
def test_eye_transform_init_square():
|
def test_eye_transform_init_square():
|
||||||
t = pt.initializers.EyeTransformInitializer()
|
t = pt.initializers.EyeLinearTransformInitializer()
|
||||||
I = t.generate(3, 3)
|
I = t.generate(3, 3)
|
||||||
assert torch.allclose(I, torch.eye(3))
|
assert torch.allclose(I, torch.eye(3))
|
||||||
|
|
||||||
|
|
||||||
def test_eye_transform_init_narrow():
|
def test_eye_transform_init_narrow():
|
||||||
t = pt.initializers.EyeTransformInitializer()
|
t = pt.initializers.EyeLinearTransformInitializer()
|
||||||
actual = t.generate(3, 2)
|
actual = t.generate(3, 2)
|
||||||
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
def test_eye_transform_init_wide():
|
def test_eye_transform_init_wide():
|
||||||
t = pt.initializers.EyeTransformInitializer()
|
t = pt.initializers.EyeLinearTransformInitializer()
|
||||||
actual = t.generate(2, 3)
|
actual = t.generate(2, 3)
|
||||||
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
Loading…
Reference in New Issue
Block a user