fix: matmul bug in

This commit is contained in:
Jensun Ravichandran
2021-06-21 22:47:45 +02:00
parent cfe09ec06b
commit bc9a826b7d
2 changed files with 14 additions and 2 deletions

View File

@@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
self._register_weights(weights)
def forward(self, x):
return x @ self.weights.T
return x @ self.weights
# Aliases