GMLMLVQ: allow for 2 or more omega layers

julius 2023-11-07 16:44:13 +01:00
parent 1786031b4e
commit c6f718a1d4
Signed by untrusted user who does not match committer: julius
GPG Key ID: 8AA3791362A8084A


@@ -1,14 +1,17 @@
 """Models based on the GLVQ framework."""

+from typing import LiteralString
+
 import torch
+from numpy.typing import NDArray
 from prototorch.core.competitions import wtac
 from prototorch.core.distances import (
+    ML_omega_distance,
     lomega_distance,
     omega_distance,
-    ML_omega_distance,
     squared_euclidean_distance,
 )
-from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI)
+from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
 from prototorch.core.losses import (
     GLVQLoss,
     lvq1_loss,
@@ -16,7 +19,7 @@ from prototorch.core.losses import (
 )
 from prototorch.core.transforms import LinearTransform
 from prototorch.nn.wrappers import LambdaLayer, LossLayer
-from torch.nn.parameter import Parameter
+from torch.nn import Parameter, ParameterList

 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 from .extras import ltangent_distance, orthogonalization
@@ -46,26 +49,28 @@ class GLVQ(SupervisedPrototypeModel):

     def initialize_prototype_win_ratios(self):
         self.register_buffer(
-            "prototype_win_ratios",
-            torch.zeros(self.num_prototypes, device=self.device))
+            "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
+        )

     def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()

     def log_prototype_win_ratios(self, distances):
         batch_size = len(distances)
-        prototype_wc = torch.zeros(self.num_prototypes,
-                                   dtype=torch.long,
-                                   device=self.device)
-        wi, wc = torch.unique(distances.min(dim=-1).indices,
-                              sorted=True,
-                              return_counts=True)
+        prototype_wc = torch.zeros(
+            self.num_prototypes, dtype=torch.long, device=self.device
+        )
+        wi, wc = torch.unique(
+            distances.min(dim=-1).indices, sorted=True, return_counts=True
+        )
         prototype_wc[wi] = wc
         prototype_wr = prototype_wc / batch_size
-        self.prototype_win_ratios = torch.vstack([
-            self.prototype_win_ratios,
-            prototype_wr,
-        ])
+        self.prototype_win_ratios = torch.vstack(
+            [
+                self.prototype_win_ratios,
+                prototype_wr,
+            ]
+        )

     def shared_step(self, batch, batch_idx):
         x, y = batch
@@ -110,11 +115,9 @@ class SiameseGLVQ(GLVQ):

     """

-    def __init__(self,
-                 hparams,
-                 backbone=torch.nn.Identity(),
-                 both_path_gradients=False,
-                 **kwargs):
+    def __init__(
+        self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
+    ):
         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
@@ -176,6 +179,7 @@ class GRLVQ(SiameseGLVQ):

     TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
     """
+
     _relevances: torch.Tensor

     def __init__(self, hparams, **kwargs):
@@ -186,8 +190,7 @@ class GRLVQ(SiameseGLVQ):
         self.register_parameter("_relevances", Parameter(relevances))

         # Override the backbone
-        self.backbone = LambdaLayer(self._apply_relevances,
-                                    name="relevance scaling")
+        self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")

     def _apply_relevances(self, x):
         return x @ torch.diag(self._relevances)
@@ -211,8 +214,9 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)

         # Override the backbone
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
         self.backbone = LinearTransform(
             self.hparams["input_dim"],
             self.hparams["latent_dim"],
@@ -232,48 +236,46 @@ class SiameseGMLVQ(SiameseGLVQ):
 class GMLMLVQ(GLVQ):
     """Generalized Multi-Layer Matrix Learning Vector Quantization.

+    Masks are applied to the omega layers to achieve sparsity and constrain
+    learning to certain items of each omega.
+
     Implemented as a regular GLVQ network that simply uses a different distance
     function. This makes it easier to implement a localized variant.
     """

     # Parameters
-    _omega_0: torch.Tensor
-    _omega_1: torch.Tensor
+    _omegas: list[torch.Tensor]
+    masks: list[torch.Tensor]

     def __init__(self, hparams, **kwargs):
         distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)

         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer")
         masks = kwargs.get("masks")
-
-        omega_0 = LLTI(masks[0]).generate(1, 1)
-        omega_1 = LLTI(masks[1]).generate(1, 1)
-        self.register_parameter("_omega_0", Parameter(omega_0))
-        self.register_parameter("_omega_1", Parameter(omega_1))
-        self.mask_0 = masks[0]
-        self.mask_1 = masks[1]
+        for i, _mask in enumerate(masks):
+            self.register_buffer(f"_mask_{i}", _mask)
+        self._masks = [self.__getattr__(f"_mask_{i}") for i,_ in enumerate(masks)]
+        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])

     @property
     def omega_matrices(self):
-        return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()]
+        return [_omega.detach().cpu() for _omega in self._omegas]

     @property
     def lambda_matrix(self):
+        # TODO update to respective lambda calculation rules.
         omega = self._omega.detach()  # (input_dim, latent_dim)
         lam = omega @ omega.T
         return lam.detach().cpu()

     def compute_distances(self, x):
         protos, _ = self.proto_layer()
-        distances = self.distance_layer(x, protos, self._omega_0,
-                                        self._omega_1, self.mask_0,
-                                        self.mask_1)
+        distances = self.distance_layer(x, protos, self._omegas, self._masks)
         return distances

     def extra_repr(self):
-        return f"(omega): (shape: {tuple(self._omega.shape)})"
+        return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"


 class GMLVQ(GLVQ):
@@ -292,10 +294,12 @@ class GMLVQ(GLVQ):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)

         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
-        omega = omega_initializer.generate(self.hparams["input_dim"],
-                                           self.hparams["latent_dim"])
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
+        omega = omega_initializer.generate(
+            self.hparams["input_dim"], self.hparams["latent_dim"]
+        )
         self.register_parameter("_omega", Parameter(omega))

     @property
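
A minimal usage sketch of the reworked GMLMLVQ constructor: it now accepts any number of masks, registers one buffer per mask, and builds a matching ParameterList of omega matrices via LLTI, so the caller-side change is simply passing a longer list of masks. The import paths, hparams keys, mask shapes, and prototypes initializer below are assumptions chosen for illustration; they are not taken from this commit.

import torch
import prototorch as pt
from prototorch.models import GMLMLVQ  # assumed: this repo's models package exports GMLMLVQ

input_dim, hidden_dim, latent_dim = 100, 32, 2

# One mask tensor per omega layer; zero entries are meant to keep the
# corresponding omega entries out of the adaptation (sparsity / constrained
# learning). The chained shapes below are an assumption for illustration.
masks = [
    torch.ones(input_dim, hidden_dim),   # omega layer 0
    torch.ones(hidden_dim, latent_dim),  # omega layer 1
    torch.ones(latent_dim, latent_dim),  # a third layer is now possible as well
]

hparams = dict(distribution=(3, 2), lr=0.01)  # assumed GLVQ-style hparams

model = GMLMLVQ(
    hparams,
    masks=masks,
    prototypes_initializer=pt.initializers.ZerosCompInitializer(input_dim),  # assumed initializer
)

print(model.omega_matrices)  # one learnable omega per mask
print(model)                 # extra_repr now reports the shape of every omega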