fix: forward of LinearTransform uses undetached weights now

This commit is contained in:
Alexander Engelsberger 2022-03-29 17:06:57 +02:00
parent ed5b9b6c62
commit 46ff1c4eb1

View File

@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
self._register_weights(weights)
def forward(self, x):
return x @ self.weights
return x @ self._weights
# Aliases