refactor: use LinearTransform instead of torch.nn.Linear
committed by Alexander Engelsberger
parent 71602bf38a
commit 72af03b991
@@ -7,6 +7,7 @@ from ..core.competitions import wtac
 from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 from ..core.initializers import EyeTransformInitializer
 from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
+from ..core.transforms import LinearTransform
 from ..nn.activations import get_activation
 from ..nn.wrappers import LambdaLayer, LossLayer
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@@ -208,18 +209,22 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)
 
         # Override the backbone
-        self.backbone = torch.nn.Linear(self.hparams.input_dim,
-                                        self.hparams.latent_dim,
-                                        bias=False)
+        omega_initializer = kwargs.get("omega_initializer",
+                                       EyeTransformInitializer())
+        self.backbone = LinearTransform(
+            self.hparams.input_dim,
+            self.hparams.output_dim,
+            initializer=omega_initializer,
+        )
 
     @property
     def omega_matrix(self):
-        return self.backbone.weight.detach().cpu()
+        return self.backbone.weights
 
     @property
     def lambda_matrix(self):
-        omega = self.backbone.weight  # (latent_dim, input_dim)
-        lam = omega.T @ omega
+        omega = self.backbone.weights  # (input_dim, latent_dim)
+        lam = omega @ omega.T
         return lam.detach().cpu()
 
 
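The commit assumes a LinearTransform from ..core.transforms that takes an initializer and exposes its parameter as `weights` with shape (input_dim, output_dim). A minimal stand-in with that interface might look like the sketch below; the class name, the `generate(in_dim, out_dim)` initializer method, and the identity default are assumptions for illustration, not the library's actual implementation.

import torch


class LinearTransformSketch(torch.nn.Module):
    """Hypothetical stand-in for the interface this diff relies on: a bias-free
    linear map whose parameter is exposed as `weights`, shaped (in_dim, out_dim),
    filled by a pluggable initializer (here via an assumed `generate` method)."""

    def __init__(self, in_dim, out_dim, initializer=None):
        super().__init__()
        if initializer is None:
            # Identity-like default, mirroring the EyeTransformInitializer fallback.
            init = torch.eye(in_dim, out_dim)
        else:
            init = initializer.generate(in_dim, out_dim)  # assumed initializer API
        self._weights = torch.nn.Parameter(init)

    @property
    def weights(self):
        # Exposed as `weights`, matching what omega_matrix reads after the refactor.
        return self._weights

    def forward(self, x):
        # (batch, in_dim) @ (in_dim, out_dim) -> (batch, out_dim)
        return x @ self._weights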
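A quick sketch (not part of the commit) of why lambda_matrix flips from omega.T @ omega to omega @ omega.T: torch.nn.Linear stores its weight as (latent_dim, input_dim), while the new backbone keeps the transposed layout (input_dim, latent_dim) per the comment in the diff, so both products yield the same (input_dim, input_dim) relevance matrix.

import torch

input_dim, latent_dim = 4, 2

# Old backbone: torch.nn.Linear weight is (latent_dim, input_dim).
omega_old = torch.randn(latent_dim, input_dim)
lam_old = omega_old.T @ omega_old  # (input_dim, input_dim)

# New backbone: weights are (input_dim, latent_dim), i.e. the transpose
# of the old layout, so the product order flips.
omega_new = omega_old.T
lam_new = omega_new @ omega_new.T  # (input_dim, input_dim)

assert torch.allclose(lam_old, lam_new)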