[BUGFIX] Add missing file
This commit is contained in:
		
							
								
								
									
										44
									
								
								prototorch/core/transforms.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								prototorch/core/transforms.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,44 @@
 | 
				
			|||||||
 | 
					"""ProtoTorch transforms"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from torch.nn.parameter import Parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .initializers import (
 | 
				
			||||||
 | 
					    AbstractLinearTransformInitializer,
 | 
				
			||||||
 | 
					    EyeTransformInitializer,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LinearTransform(torch.nn.Module):
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					            self,
 | 
				
			||||||
 | 
					            in_dim: int,
 | 
				
			||||||
 | 
					            out_dim: int,
 | 
				
			||||||
 | 
					            initializer:
 | 
				
			||||||
 | 
					        AbstractLinearTransformInitializer = EyeTransformInitializer(),
 | 
				
			||||||
 | 
					            **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(**kwargs)
 | 
				
			||||||
 | 
					        self.set_weights(in_dim, out_dim, initializer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def weights(self):
 | 
				
			||||||
 | 
					        return self._weights.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _register_weights(self, weights):
 | 
				
			||||||
 | 
					        self.register_parameter("_weights", Parameter(weights))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def set_weights(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        in_dim: int,
 | 
				
			||||||
 | 
					        out_dim: int,
 | 
				
			||||||
 | 
					        initializer:
 | 
				
			||||||
 | 
					        AbstractLinearTransformInitializer = EyeTransformInitializer()):
 | 
				
			||||||
 | 
					        weights = initializer.generate(in_dim, out_dim)
 | 
				
			||||||
 | 
					        self._register_weights(weights)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self, x):
 | 
				
			||||||
 | 
					        return x @ self.weights.T
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Aliases
 | 
				
			||||||
 | 
					Omega = LinearTransform
 | 
				
			||||||
		Reference in New Issue
	
	Block a user