fix: correct typo

This commit is contained in:
Jensun Ravichandran 2022-04-04 21:50:22 +02:00
parent 71a2e74eff
commit 0d10fc7e25
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921
2 changed files with 6 additions and 6 deletions

View File

@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter
from .initializers import ( from .initializers import (
AbstractLinearTransformInitializer, AbstractLinearTransformInitializer,
EyeTransformInitializer, EyeLinearTransformInitializer,
) )
@ -15,7 +15,7 @@ class LinearTransform(torch.nn.Module):
in_dim: int, in_dim: int,
out_dim: int, out_dim: int,
initializer: initializer:
AbstractLinearTransformInitializer = EyeTransformInitializer()): AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
super().__init__() super().__init__()
self.set_weights(in_dim, out_dim, initializer) self.set_weights(in_dim, out_dim, initializer)
@ -31,7 +31,7 @@ class LinearTransform(torch.nn.Module):
in_dim: int, in_dim: int,
out_dim: int, out_dim: int,
initializer: initializer:
AbstractLinearTransformInitializer = EyeTransformInitializer()): AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
weights = initializer.generate(in_dim, out_dim) weights = initializer.generate(in_dim, out_dim)
self._register_weights(weights) self._register_weights(weights)

View File

@ -245,20 +245,20 @@ def test_random_reasonings_init_channels_not_first():
# Transform initializers # Transform initializers
def test_eye_transform_init_square(): def test_eye_transform_init_square():
t = pt.initializers.EyeTransformInitializer() t = pt.initializers.EyeLinearTransformInitializer()
I = t.generate(3, 3) I = t.generate(3, 3)
assert torch.allclose(I, torch.eye(3)) assert torch.allclose(I, torch.eye(3))
def test_eye_transform_init_narrow(): def test_eye_transform_init_narrow():
t = pt.initializers.EyeTransformInitializer() t = pt.initializers.EyeLinearTransformInitializer()
actual = t.generate(3, 2) actual = t.generate(3, 2)
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]]) desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
assert torch.allclose(actual, desired) assert torch.allclose(actual, desired)
def test_eye_transform_init_wide(): def test_eye_transform_init_wide():
t = pt.initializers.EyeTransformInitializer() t = pt.initializers.EyeLinearTransformInitializer()
actual = t.generate(2, 3) actual = t.generate(2, 3)
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]]) desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
assert torch.allclose(actual, desired) assert torch.allclose(actual, desired)