fix: forward of LinearTransform uses undetached weights now
This commit is contained in:
		@@ -36,7 +36,7 @@ class LinearTransform(torch.nn.Module):
 | 
				
			|||||||
        self._register_weights(weights)
 | 
					        self._register_weights(weights)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def forward(self, x):
 | 
					    def forward(self, x):
 | 
				
			||||||
        return x @ self.weights
 | 
					        return x @ self._weights
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Aliases
 | 
					# Aliases
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user