diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py index efac17c..5e00bb0 100644 --- a/prototorch/core/transforms.py +++ b/prototorch/core/transforms.py @@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module): self._register_weights(weights) def forward(self, x): - return x @ self.weights.T + return x @ self.weights # Aliases diff --git a/tests/test_core.py b/tests/test_core.py index d007f9b..816cac7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -265,13 +265,25 @@ def test_eye_transform_init_wide(): # Transforms -def test_linear_transform(): +def test_linear_transform_default_eye_init(): l = pt.transforms.LinearTransform(2, 4) actual = l.weights desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]]) assert torch.allclose(actual, desired) +def test_linear_transform_forward(): + l = pt.transforms.LinearTransform(4, 2) + actual_weights = l.weights + desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]]) + assert torch.allclose(actual_weights, desired_weights) + actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \ + [1.1, 2.2, 3.3, 4.4], \ + [5.5, 6.6, 7.7, 8.8]])) + desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]]) + assert torch.allclose(actual_outputs, desired_outputs) + + def test_linear_transform_zeros_init(): l = pt.transforms.LinearTransform( in_dim=2,