From 72af03b99114d0ac11cf584dcd9129e8d3a7529d Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Mon, 21 Jun 2021 22:52:22 +0200
Subject: [PATCH] refactor: use `LinearTransform` instead of `torch.nn.Linear`

---
 prototorch/models/glvq.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
index cc8ed57..6ab7625 100644
--- a/prototorch/models/glvq.py
+++ b/prototorch/models/glvq.py
@@ -7,6 +7,7 @@ from ..core.competitions import wtac
 from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 from ..core.initializers import EyeTransformInitializer
 from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
+from ..core.transforms import LinearTransform
 from ..nn.activations import get_activation
 from ..nn.wrappers import LambdaLayer, LossLayer
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@@ -208,18 +209,22 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)
 
         # Override the backbone
-        self.backbone = torch.nn.Linear(self.hparams.input_dim,
-                                        self.hparams.latent_dim,
-                                        bias=False)
+        omega_initializer = kwargs.get("omega_initializer",
+                                       EyeTransformInitializer())
+        self.backbone = LinearTransform(
+            self.hparams.input_dim,
+            self.hparams.latent_dim,
+            initializer=omega_initializer,
+        )
 
     @property
     def omega_matrix(self):
-        return self.backbone.weight.detach().cpu()
+        return self.backbone.weights
 
     @property
     def lambda_matrix(self):
-        omega = self.backbone.weight  # (latent_dim, input_dim)
-        lam = omega.T @ omega
+        omega = self.backbone.weights  # (input_dim, latent_dim)
+        lam = omega @ omega.T
         return lam.detach().cpu()
 
 
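
Reviewer note: below is a minimal usage sketch of the refactored backbone. It
assumes the `prototorch` core API as imported in this patch (`LinearTransform`,
`EyeTransformInitializer`), that `LinearTransform` is a `torch.nn.Module` that
projects inputs as `x @ weights`, and hypothetical dimensions (4 -> 2):

    import torch

    from prototorch.core.initializers import EyeTransformInitializer
    from prototorch.core.transforms import LinearTransform

    # Unlike torch.nn.Linear, whose weight has shape (out_features,
    # in_features), LinearTransform keeps its weights as (input_dim,
    # latent_dim) -- hence the switch from `omega.T @ omega` to
    # `omega @ omega.T` in `lambda_matrix` above.
    backbone = LinearTransform(4, 2, initializer=EyeTransformInitializer())

    x = torch.randn(10, 4)    # a batch of ten 4-dimensional inputs
    z = backbone(x)           # projected batch, shape (10, 2)

    omega = backbone.weights  # omega matrix, shape (4, 2)
    lam = omega @ omega.T     # relevance (lambda) matrix, shape (4, 4)

With the identity ("eye") initializer, omega starts as a truncated identity, so
the initial lambda matrix is a diagonal projection onto the first latent_dim
input dimensions; training then adapts omega, and with it the relevances.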