fix: matmul bug in LinearTransform

Jensun Ravichandran 2021-06-21 22:47:45 +02:00
parent cfe09ec06b
commit bc9a826b7d
2 changed files with 14 additions and 2 deletions


@@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
         self._register_weights(weights)
 
     def forward(self, x):
-        return x @ self.weights.T
+        return x @ self.weights
 
 
 # Aliases
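For context, a minimal sketch of the shape reasoning behind the one-line fix, assuming (as the tests below expect) that the weights are registered with shape (in_dim, out_dim). The stripped-down class here is an illustrative stand-in, not the actual prototorch implementation:

import torch

# Stand-in module: weights stored with shape (in_dim, out_dim),
# matching the eye-init tests in the second file of this commit.
class LinearTransform(torch.nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.eye(in_dim, out_dim))

    def forward(self, x):
        # Fixed version: (batch, in_dim) @ (in_dim, out_dim) -> (batch, out_dim)
        return x @ self.weights

t = LinearTransform(4, 2)                 # weights: [[1, 0], [0, 1], [0, 0], [0, 0]]
x = torch.Tensor([[1.1, 2.2, 3.3, 4.4]])  # shape (1, 4)
print(t(x))                               # [[1.1, 2.2]], i.e. the first two features
# The old `x @ self.weights.T` would attempt (1, 4) @ (2, 4) and raise a
# shape-mismatch RuntimeError whenever in_dim != out_dim.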


@@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
 # Transforms
-def test_linear_transform():
+def test_linear_transform_default_eye_init():
     l = pt.transforms.LinearTransform(2, 4)
     actual = l.weights
     desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
     assert torch.allclose(actual, desired)
+
+
+def test_linear_transform_forward():
+    l = pt.transforms.LinearTransform(4, 2)
+    actual_weights = l.weights
+    desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
+    assert torch.allclose(actual_weights, desired_weights)
+    actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
+                                     [1.1, 2.2, 3.3, 4.4], \
+                                     [5.5, 6.6, 7.7, 8.8]]))
+    desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
+    assert torch.allclose(actual_outputs, desired_outputs)
 
 
 def test_linear_transform_zeros_init():
     l = pt.transforms.LinearTransform(
         in_dim=2,