diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py index 5e00bb0..04901cd 100644 --- a/prototorch/core/transforms.py +++ b/prototorch/core/transforms.py @@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module): self._register_weights(weights) def forward(self, x): - return x @ self.weights + return x @ self._weights # Aliases