fix: matmul bug in
This commit is contained in:
parent
cfe09ec06b
commit
bc9a826b7d
@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
|
|||||||
self._register_weights(weights)
|
self._register_weights(weights)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return x @ self.weights.T
|
return x @ self.weights
|
||||||
|
|
||||||
|
|
||||||
# Aliases
|
# Aliases
|
||||||
|
@ -265,13 +265,25 @@ def test_eye_transform_init_wide():
|
|||||||
|
|
||||||
|
|
||||||
# Transforms
|
# Transforms
|
||||||
def test_linear_transform():
|
def test_linear_transform_default_eye_init():
|
||||||
l = pt.transforms.LinearTransform(2, 4)
|
l = pt.transforms.LinearTransform(2, 4)
|
||||||
actual = l.weights
|
actual = l.weights
|
||||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||||
assert torch.allclose(actual, desired)
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_transform_forward():
|
||||||
|
l = pt.transforms.LinearTransform(4, 2)
|
||||||
|
actual_weights = l.weights
|
||||||
|
desired_weights = torch.Tensor([[1, 0], [0, 1], [0, 0], [0, 0]])
|
||||||
|
assert torch.allclose(actual_weights, desired_weights)
|
||||||
|
actual_outputs = l(torch.Tensor([[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[1.1, 2.2, 3.3, 4.4], \
|
||||||
|
[5.5, 6.6, 7.7, 8.8]]))
|
||||||
|
desired_outputs = torch.Tensor([[1.1, 2.2], [1.1, 2.2], [5.5, 6.6]])
|
||||||
|
assert torch.allclose(actual_outputs, desired_outputs)
|
||||||
|
|
||||||
|
|
||||||
def test_linear_transform_zeros_init():
|
def test_linear_transform_zeros_init():
|
||||||
l = pt.transforms.LinearTransform(
|
l = pt.transforms.LinearTransform(
|
||||||
in_dim=2,
|
in_dim=2,
|
||||||
|
Loading…
Reference in New Issue
Block a user