diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py index 04901cd..35ea85b 100644 --- a/prototorch/core/transforms.py +++ b/prototorch/core/transforms.py @@ -38,6 +38,9 @@ class LinearTransform(torch.nn.Module): def forward(self, x): return x @ self._weights + def extra_repr(self): + return f"weights: (shape: {tuple(self._weights.shape)})" + # Aliases Omega = LinearTransform