fix: matmul bug in
This commit is contained in:
parent
cfe09ec06b
commit
bc9a826b7d
@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
|
||||
self._register_weights(weights)
|
||||
|
||||
def forward(self, x):
|
||||
return x @ self.weights.T
|
||||
return x @ self.weights
|
||||
|
||||
|
||||
# Aliases
|
||||
|
@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
|
||||
|
||||
|
||||
# Transforms
|
||||
def test_linear_transform():
|
||||
def test_linear_transform_default_eye_init():
|
||||
l = pt.transforms.LinearTransform(2, 4)
|
||||
actual = l.weights
|
||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_forward():
|
||||
l = pt.transforms.LinearTransform(4, 2)
|
||||
actual_weights = l.weights
|
||||
desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
|
||||
assert torch.allclose(actual_weights, desired_weights)
|
||||
actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
|
||||
[1.1, 2.2, 3.3, 4.4], \
|
||||
[5.5, 6.6, 7.7, 8.8]]))
|
||||
desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
|
||||
assert torch.allclose(actual_outputs, desired_outputs)
|
||||
|
||||
|
||||
def test_linear_transform_zeros_init():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
|
Loading…
Reference in New Issue
Block a user