refactor(api)!: merge the new api changes into dev

This commit is contained in:
Jensun Ravichandran
2021-06-20 19:00:12 +02:00
30 changed files with 368 additions and 457 deletions

View File

@@ -1,16 +1,14 @@
"""Models based on the GLVQ framework."""
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (lomega_distance, omega_distance,
squared_euclidean_distance)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.components import LinearMapping
from prototorch.modules import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from ..core.competitions import wtac
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
from ..core.initializers import EyeTransformInitializer
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
from ..nn.activations import get_activation
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@@ -30,9 +28,6 @@ class GLVQ(SupervisedPrototypeModel):
# Loss
self.loss = LossLayer(glvq_loss)
# Prototype metrics
self.initialize_prototype_win_ratios()
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
@@ -59,7 +54,7 @@ class GLVQ(SupervisedPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.compute_distances(x)
plabels = self.proto_layer.component_labels
plabels = self.proto_layer.labels
mu = self.loss(out, y, prototype_labels=plabels)
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
loss = batch_loss.sum(dim=0)
@@ -135,7 +130,7 @@ class SiameseGLVQ(GLVQ):
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
x, protos = [arr.view(arr.size(0), -1) for arr in (x, protos)]
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos)
@@ -240,18 +235,14 @@ class GMLVQ(GLVQ):
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Additional parameters
omega_initializer = kwargs.get("omega_initializer", None)
initialized_omega = kwargs.get("initialized_omega", None)
if omega_initializer is not None or initialized_omega is not None:
self.omega_layer = LinearMapping(
mapping_shape=(self.hparams.input_dim, self.hparams.latent_dim),
initializer=omega_initializer,
initialized_linearmapping=initialized_omega,
)
omega_initializer = kwargs.get("omega_initializer",
EyeTransformInitializer())
omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
self.register_parameter("_omega", Parameter(omega))
self.backbone = LambdaLayer(lambda x: x @ self._omega,
name="omega matrix")
self.register_parameter("_omega", Parameter(self.omega_layer.mapping))
self.backbone = LambdaLayer(lambda x: x @ self._omega, name = "omega matrix")
@property
def omega_matrix(self):
return self._omega.detach().cpu()
@@ -264,24 +255,6 @@ class GMLVQ(GLVQ):
def extra_repr(self):
return f"(omega): (shape: {tuple(self._omega.shape)})"
def predict_latent(self, x, map_protos=True):
"""Predict `x` assuming it is already embedded in the latent space.
Only the prototypes are embedded in the latent space using the
backbone.
"""
self.eval()
with torch.no_grad():
protos, plabels = self.proto_layer()
if map_protos:
protos = self.backbone(protos)
d = squared_euclidean_distance(x, protos)
y_pred = wtac(d, plabels)
return y_pred
class LGMLVQ(GMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization."""