From 85f75bb28cf626f08b0f2d4b18ac05de1e958a2e Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 1 Apr 2022 10:13:25 +0200 Subject: [PATCH] feat: add repr for `LinearTransform` --- prototorch/core/transforms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py index 04901cd..35ea85b 100644 --- a/prototorch/core/transforms.py +++ b/prototorch/core/transforms.py @@ -38,6 +38,9 @@ class LinearTransform(torch.nn.Module): def forward(self, x): return x @ self._weights + def extra_repr(self): + return f"weights: (shape: {tuple(self._weights.shape)})" + # Aliases Omega = LinearTransform