refactor: use LinearTransform
instead of torch.nn.Linear
This commit is contained in:
parent
71602bf38a
commit
72af03b991
@ -7,6 +7,7 @@ from ..core.competitions import wtac
|
|||||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||||
from ..core.initializers import EyeTransformInitializer
|
from ..core.initializers import EyeTransformInitializer
|
||||||
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
from ..core.transforms import LinearTransform
|
||||||
from ..nn.activations import get_activation
|
from ..nn.activations import get_activation
|
||||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
@ -208,18 +209,22 @@ class SiameseGMLVQ(SiameseGLVQ):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Override the backbone
|
# Override the backbone
|
||||||
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
omega_initializer = kwargs.get("omega_initializer",
|
||||||
self.hparams.latent_dim,
|
EyeTransformInitializer())
|
||||||
bias=False)
|
self.backbone = LinearTransform(
|
||||||
|
self.hparams.input_dim,
|
||||||
|
self.hparams.output_dim,
|
||||||
|
initializer=omega_initializer,
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
||||||
return self.backbone.weight.detach().cpu()
|
return self.backbone.weights
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lambda_matrix(self):
|
def lambda_matrix(self):
|
||||||
omega = self.backbone.weight # (latent_dim, input_dim)
|
omega = self.backbone.weight # (input_dim, latent_dim)
|
||||||
lam = omega.T @ omega
|
lam = omega @ omega.T
|
||||||
return lam.detach().cpu()
|
return lam.detach().cpu()
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user