refactor: use LinearTransform instead of torch.nn.Linear

Jensun Ravichandran 2021-06-21 22:52:22 +02:00 committed by Alexander Engelsberger
parent 71602bf38a
commit 72af03b991

@@ -7,6 +7,7 @@ from ..core.competitions import wtac
 from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 from ..core.initializers import EyeTransformInitializer
 from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
+from ..core.transforms import LinearTransform
 from ..nn.activations import get_activation
 from ..nn.wrappers import LambdaLayer, LossLayer
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@@ -208,18 +209,22 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)
 
         # Override the backbone
-        self.backbone = torch.nn.Linear(self.hparams.input_dim,
-                                        self.hparams.latent_dim,
-                                        bias=False)
+        omega_initializer = kwargs.get("omega_initializer",
+                                       EyeTransformInitializer())
+        self.backbone = LinearTransform(
+            self.hparams.input_dim,
+            self.hparams.output_dim,
+            initializer=omega_initializer,
+        )
 
     @property
     def omega_matrix(self):
-        return self.backbone.weight.detach().cpu()
+        return self.backbone.weights
 
     @property
     def lambda_matrix(self):
-        omega = self.backbone.weight  # (latent_dim, input_dim)
-        lam = omega.T @ omega
+        omega = self.backbone.weights  # (input_dim, latent_dim)
+        lam = omega @ omega.T
        return lam.detach().cpu()
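
For context, the sketch below illustrates what the refactor amounts to. The LinearTransform class here is a simplified stand-in written for illustration, not ProtoTorch's actual implementation; following the diff, the backbone stores an omega matrix of shape (input_dim, latent_dim), exposes it via a weights property, and the GMLVQ relevance matrix becomes lam = omega @ omega.T. Before the refactor, torch.nn.Linear stored its weight as (latent_dim, input_dim), which is why the old code computed lam = omega.T @ omega instead.

    import torch

    class LinearTransform(torch.nn.Module):
        # Simplified stand-in: a bias-free linear map with an explicit
        # (input_dim, latent_dim) weight matrix, initialized identity-like
        # in the spirit of EyeTransformInitializer.
        def __init__(self, input_dim, latent_dim):
            super().__init__()
            self._weights = torch.nn.Parameter(torch.eye(input_dim, latent_dim))

        @property
        def weights(self):
            return self._weights

        def forward(self, x):
            # Applies x @ omega, i.e. the transpose of torch.nn.Linear's
            # y = x @ weight.T convention; this shape flip is why the
            # lambda computation changes from omega.T @ omega to
            # omega @ omega.T in the diff above.
            return x @ self._weights

    backbone = LinearTransform(input_dim=4, latent_dim=2)
    omega = backbone.weights                # (input_dim, latent_dim)
    lam = (omega @ omega.T).detach().cpu()  # (input_dim, input_dim)
    print(lam.shape)                        # torch.Size([4, 4])

Both conventions yield the same (input_dim, input_dim) relevance matrix, since transposing omega swaps omega.T @ omega into omega @ omega.T; only the storage layout of the omega matrix changed.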