diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py index 35ea85b..4065a68 100644 --- a/prototorch/core/transforms.py +++ b/prototorch/core/transforms.py @@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter from .initializers import ( AbstractLinearTransformInitializer, - EyeTransformInitializer, + EyeLinearTransformInitializer, ) @@ -15,7 +15,7 @@ class LinearTransform(torch.nn.Module): in_dim: int, out_dim: int, initializer: - AbstractLinearTransformInitializer = EyeTransformInitializer()): + AbstractLinearTransformInitializer = EyeLinearTransformInitializer()): super().__init__() self.set_weights(in_dim, out_dim, initializer) @@ -31,7 +31,7 @@ class LinearTransform(torch.nn.Module): in_dim: int, out_dim: int, initializer: - AbstractLinearTransformInitializer = EyeTransformInitializer()): + AbstractLinearTransformInitializer = EyeLinearTransformInitializer()): weights = initializer.generate(in_dim, out_dim) self._register_weights(weights) diff --git a/tests/test_core.py b/tests/test_core.py index 816cac7..0b2a220 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -245,20 +245,20 @@ def test_random_reasonings_init_channels_not_first(): # Transform initializers def test_eye_transform_init_square(): - t = pt.initializers.EyeTransformInitializer() + t = pt.initializers.EyeLinearTransformInitializer() I = t.generate(3, 3) assert torch.allclose(I, torch.eye(3)) def test_eye_transform_init_narrow(): - t = pt.initializers.EyeTransformInitializer() + t = pt.initializers.EyeLinearTransformInitializer() actual = t.generate(3, 2) desired = torch.Tensor([[1, 0], [0, 1], [0, 0]]) assert torch.allclose(actual, desired) def test_eye_transform_init_wide(): - t = pt.initializers.EyeTransformInitializer() + t = pt.initializers.EyeLinearTransformInitializer() actual = t.generate(2, 3) desired = torch.Tensor([[1, 0, 0], [0, 1, 0]]) assert torch.allclose(actual, desired)