fix: correct typo

This commit is contained in:
Jensun Ravichandran 2022-04-04 21:50:22 +02:00
parent 71a2e74eff
commit 0d10fc7e25
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921
2 changed files with 6 additions and 6 deletions

View File

@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter
from .initializers import (
AbstractLinearTransformInitializer,
EyeTransformInitializer,
EyeLinearTransformInitializer,
)
@ -15,7 +15,7 @@ class LinearTransform(torch.nn.Module):
in_dim: int,
out_dim: int,
initializer:
AbstractLinearTransformInitializer = EyeTransformInitializer()):
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
super().__init__()
self.set_weights(in_dim, out_dim, initializer)
@ -31,7 +31,7 @@ class LinearTransform(torch.nn.Module):
in_dim: int,
out_dim: int,
initializer:
AbstractLinearTransformInitializer = EyeTransformInitializer()):
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
weights = initializer.generate(in_dim, out_dim)
self._register_weights(weights)

View File

@ -245,20 +245,20 @@ def test_random_reasonings_init_channels_not_first():
# Transform initializers
def test_eye_transform_init_square():
t = pt.initializers.EyeTransformInitializer()
t = pt.initializers.EyeLinearTransformInitializer()
I = t.generate(3, 3)
assert torch.allclose(I, torch.eye(3))
def test_eye_transform_init_narrow():
t = pt.initializers.EyeTransformInitializer()
t = pt.initializers.EyeLinearTransformInitializer()
actual = t.generate(3, 2)
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
assert torch.allclose(actual, desired)
def test_eye_transform_init_wide():
t = pt.initializers.EyeTransformInitializer()
t = pt.initializers.EyeLinearTransformInitializer()
actual = t.generate(2, 3)
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
assert torch.allclose(actual, desired)