feat: add repr for LinearTransform

This commit is contained in:
Jensun Ravichandran 2022-04-01 10:13:25 +02:00
parent 46ff1c4eb1
commit 85f75bb28c
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -38,6 +38,9 @@ class LinearTransform(torch.nn.Module):
def forward(self, x):
return x @ self._weights
def extra_repr(self):
return f"weights: (shape: {tuple(self._weights.shape)})"
# Aliases
Omega = LinearTransform