Compare commits

...

5 Commits

Author  SHA1        Message                                             Date
julius  adafb49985  masks -> ParameterList(requires_grad=False)        2023-11-07 19:17:43 +01:00
julius  78f8b6cc00  remove accidental LiteralString import             2023-11-07 18:52:51 +01:00
julius  c6f718a1d4  GMLMLVQ: allow for 2 or more omega layers          2023-11-07 16:44:13 +01:00
julius  1786031b4e  adjust omega_matrix property                       2023-11-06 16:32:57 +01:00
julius  824dfced92  Implement a prototypical 2-layer version of GMLVQ  2023-11-03 14:59:00 +01:00
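The newest commit's move from plain tensors to a ParameterList of frozen Parameters is worth a gloss. A minimal sketch, assuming nothing beyond standard torch.nn (the class and mask values here are illustrative, not from the repo): the masks register with the module, so they move with .to(device) and land in state_dict, yet the optimizer never touches them.

import torch
from torch.nn import Module, Parameter, ParameterList

class MaskedModule(Module):
    """Toy module holding fixed masks the way the diff below does."""

    def __init__(self, masks):
        super().__init__()
        # Registered like parameters, but requires_grad=False keeps them
        # out of gradient updates.
        self._masks = ParameterList(
            [Parameter(mask, requires_grad=False) for mask in masks]
        )

module = MaskedModule([torch.eye(3), torch.ones(2, 3)])
print([p.requires_grad for p in module.parameters()])  # [False, False]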


@@ -1,13 +1,15 @@
 """Models based on the GLVQ framework."""

 import torch
+from numpy.typing import NDArray
 from prototorch.core.competitions import wtac
 from prototorch.core.distances import (
+    ML_omega_distance,
     lomega_distance,
     omega_distance,
     squared_euclidean_distance,
 )
-from prototorch.core.initializers import EyeLinearTransformInitializer
+from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
 from prototorch.core.losses import (
     GLVQLoss,
     lvq1_loss,
@@ -15,7 +17,7 @@ from prototorch.core.losses import (
 )
 from prototorch.core.transforms import LinearTransform
 from prototorch.nn.wrappers import LambdaLayer, LossLayer
-from torch.nn.parameter import Parameter
+from torch.nn import Parameter, ParameterList

 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
 from .extras import ltangent_distance, orthogonalization
@@ -45,26 +47,28 @@ class GLVQ(SupervisedPrototypeModel):

     def initialize_prototype_win_ratios(self):
         self.register_buffer(
-            "prototype_win_ratios",
-            torch.zeros(self.num_prototypes, device=self.device))
+            "prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
+        )

     def on_train_epoch_start(self):
         self.initialize_prototype_win_ratios()

     def log_prototype_win_ratios(self, distances):
         batch_size = len(distances)
-        prototype_wc = torch.zeros(self.num_prototypes,
-                                   dtype=torch.long,
-                                   device=self.device)
-        wi, wc = torch.unique(distances.min(dim=-1).indices,
-                              sorted=True,
-                              return_counts=True)
+        prototype_wc = torch.zeros(
+            self.num_prototypes, dtype=torch.long, device=self.device
+        )
+        wi, wc = torch.unique(
+            distances.min(dim=-1).indices, sorted=True, return_counts=True
+        )
         prototype_wc[wi] = wc
         prototype_wr = prototype_wc / batch_size
-        self.prototype_win_ratios = torch.vstack([
+        self.prototype_win_ratios = torch.vstack(
+            [
             self.prototype_win_ratios,
             prototype_wr,
-        ])
+            ]
+        )

     def shared_step(self, batch, batch_idx):
         x, y = batch
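A quick worked example of the win-ratio bookkeeping reformatted above (values invented): each row's argmin is the winning prototype, and torch.unique with return_counts tallies the wins per prototype.

import torch

distances = torch.tensor([[0.1, 0.9, 0.5],
                          [0.7, 0.2, 0.4],
                          [0.3, 0.8, 0.6]])   # 3 samples x 3 prototypes
winners = distances.min(dim=-1).indices       # tensor([0, 1, 0])
wi, wc = torch.unique(winners, sorted=True, return_counts=True)
prototype_wc = torch.zeros(3, dtype=torch.long)
prototype_wc[wi] = wc                         # tensor([2, 1, 0])
prototype_wr = prototype_wc / len(distances)  # tensor([0.6667, 0.3333, 0.0000])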
@@ -109,11 +113,9 @@ class SiameseGLVQ(GLVQ):

     """

-    def __init__(self,
-                 hparams,
-                 backbone=torch.nn.Identity(),
-                 both_path_gradients=False,
-                 **kwargs):
+    def __init__(
+        self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
+    ):
         distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
@@ -175,6 +177,7 @@ class GRLVQ(SiameseGLVQ):

     TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
     """
+
     _relevances: torch.Tensor

     def __init__(self, hparams, **kwargs):
@@ -185,8 +188,7 @@ class GRLVQ(SiameseGLVQ):
         self.register_parameter("_relevances", Parameter(relevances))

         # Override the backbone
-        self.backbone = LambdaLayer(self._apply_relevances,
-                                    name="relevance scaling")
+        self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")

     def _apply_relevances(self, x):
         return x @ torch.diag(self._relevances)
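As a sanity check on the relevance backbone above: right-multiplying by diag(r) scales each feature by its relevance, so a zero relevance switches a feature off entirely (values invented).

import torch

x = torch.tensor([[1.0, 2.0, 3.0]])
r = torch.tensor([0.5, 1.0, 0.0])  # third feature deemed irrelevant
print(x @ torch.diag(r))           # tensor([[0.5000, 2.0000, 0.0000]])
print(torch.allclose(x @ torch.diag(r), x * r))  # True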
@@ -210,8 +212,9 @@ class SiameseGMLVQ(SiameseGLVQ):
         super().__init__(hparams, **kwargs)

         # Override the backbone
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
         self.backbone = LinearTransform(
             self.hparams["input_dim"],
             self.hparams["latent_dim"],
@@ -229,6 +232,49 @@ class SiameseGMLVQ(SiameseGLVQ):
         return lam.detach().cpu()


+class GMLMLVQ(GLVQ):
+    """Generalized Multi-Layer Matrix Learning Vector Quantization.
+
+    Masks are applied to the omega layers to achieve sparsity and constrain
+    learning to certain items of each omega.
+
+    Implemented as a regular GLVQ network that simply uses a different distance
+    function. This makes it easier to implement a localized variant.
+    """
+
+    # Parameters
+    _omegas: list[torch.Tensor]
+    _masks: list[torch.Tensor]
+
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Additional parameters
+        self._masks = ParameterList(
+            [Parameter(mask, requires_grad=False) for mask in kwargs.get("masks")]
+        )
+        self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in self._masks])
+
+    @property
+    def omega_matrices(self):
+        return [_omega.detach().cpu() for _omega in self._omegas]
+
+    @property
+    def lambda_matrix(self):
+        # TODO update to respective lambda calculation rules.
+        omega = self._omega.detach()  # (input_dim, latent_dim)
+        lam = omega @ omega.T
+        return lam.detach().cpu()
+
+    def compute_distances(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos, self._omegas, self._masks)
+        return distances
+
+    def extra_repr(self):
+        return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"
+

 class GMLVQ(GLVQ):
     """Generalized Matrix Learning Vector Quantization.

@@ -245,10 +291,12 @@ class GMLVQ(GLVQ):
         super().__init__(hparams, distance_fn=distance_fn, **kwargs)

         # Additional parameters
-        omega_initializer = kwargs.get("omega_initializer",
-                                       EyeLinearTransformInitializer())
-        omega = omega_initializer.generate(self.hparams["input_dim"],
-                                           self.hparams["latent_dim"])
+        omega_initializer = kwargs.get(
+            "omega_initializer", EyeLinearTransformInitializer()
+        )
+        omega = omega_initializer.generate(
+            self.hparams["input_dim"], self.hparams["latent_dim"]
+        )
         self.register_parameter("_omega", Parameter(omega))

     @property
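The omega_initializer hook reformatted above accepts anything with the .generate(input_dim, latent_dim) interface. Both initializers below appear in this diff; the literal matrix is invented, and LLTI evidently ignores the requested dimensions (the diff calls it as generate(1, 1) on masks that are not 1x1).

import torch
from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer

omega_eye = EyeLinearTransformInitializer().generate(4, 2)  # rectangular identity
custom = torch.full((4, 2), 0.5)                            # hand-picked omega
omega_lit = LLTI(custom).generate(4, 2)                     # yields the literal matrix
print(omega_eye.shape, omega_lit.shape)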