fix: matmul bug in

This commit is contained in:
Jensun Ravichandran
2021-06-21 22:47:45 +02:00
parent cfe09ec06b
commit bc9a826b7d
2 changed files with 14 additions and 2 deletions

View File

@@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
# Transforms
def test_linear_transform():
def test_linear_transform_default_eye_init():
l = pt.transforms.LinearTransform(2, 4)
actual = l.weights
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
assert torch.allclose(actual, desired)
def test_linear_transform_forward():
l = pt.transforms.LinearTransform(4, 2)
actual_weights = l.weights
desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
assert torch.allclose(actual_weights, desired_weights)
actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
[1.1, 2.2, 3.3, 4.4], \
[5.5, 6.6, 7.7, 8.8]]))
desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
assert torch.allclose(actual_outputs, desired_outputs)
def test_linear_transform_zeros_init():
l = pt.transforms.LinearTransform(
in_dim=2,