GMLMLVQ: allow for 2 or more omega layers
This commit is contained in:
parent
1786031b4e
commit
c6f718a1d4
@ -1,14 +1,17 @@
|
|||||||
"""Models based on the GLVQ framework."""
|
"""Models based on the GLVQ framework."""
|
||||||
|
|
||||||
|
from typing import LiteralString
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from numpy.typing import NDArray
|
||||||
from prototorch.core.competitions import wtac
|
from prototorch.core.competitions import wtac
|
||||||
from prototorch.core.distances import (
|
from prototorch.core.distances import (
|
||||||
|
ML_omega_distance,
|
||||||
lomega_distance,
|
lomega_distance,
|
||||||
omega_distance,
|
omega_distance,
|
||||||
ML_omega_distance,
|
|
||||||
squared_euclidean_distance,
|
squared_euclidean_distance,
|
||||||
)
|
)
|
||||||
from prototorch.core.initializers import (EyeLinearTransformInitializer, LLTI)
|
from prototorch.core.initializers import LLTI, EyeLinearTransformInitializer
|
||||||
from prototorch.core.losses import (
|
from prototorch.core.losses import (
|
||||||
GLVQLoss,
|
GLVQLoss,
|
||||||
lvq1_loss,
|
lvq1_loss,
|
||||||
@ -16,7 +19,7 @@ from prototorch.core.losses import (
|
|||||||
)
|
)
|
||||||
from prototorch.core.transforms import LinearTransform
|
from prototorch.core.transforms import LinearTransform
|
||||||
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn import Parameter, ParameterList
|
||||||
|
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
from .extras import ltangent_distance, orthogonalization
|
from .extras import ltangent_distance, orthogonalization
|
||||||
@ -46,26 +49,28 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
|
|
||||||
def initialize_prototype_win_ratios(self):
|
def initialize_prototype_win_ratios(self):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"prototype_win_ratios",
|
"prototype_win_ratios", torch.zeros(self.num_prototypes, device=self.device)
|
||||||
torch.zeros(self.num_prototypes, device=self.device))
|
)
|
||||||
|
|
||||||
def on_train_epoch_start(self):
|
def on_train_epoch_start(self):
|
||||||
self.initialize_prototype_win_ratios()
|
self.initialize_prototype_win_ratios()
|
||||||
|
|
||||||
def log_prototype_win_ratios(self, distances):
|
def log_prototype_win_ratios(self, distances):
|
||||||
batch_size = len(distances)
|
batch_size = len(distances)
|
||||||
prototype_wc = torch.zeros(self.num_prototypes,
|
prototype_wc = torch.zeros(
|
||||||
dtype=torch.long,
|
self.num_prototypes, dtype=torch.long, device=self.device
|
||||||
device=self.device)
|
)
|
||||||
wi, wc = torch.unique(distances.min(dim=-1).indices,
|
wi, wc = torch.unique(
|
||||||
sorted=True,
|
distances.min(dim=-1).indices, sorted=True, return_counts=True
|
||||||
return_counts=True)
|
)
|
||||||
prototype_wc[wi] = wc
|
prototype_wc[wi] = wc
|
||||||
prototype_wr = prototype_wc / batch_size
|
prototype_wr = prototype_wc / batch_size
|
||||||
self.prototype_win_ratios = torch.vstack([
|
self.prototype_win_ratios = torch.vstack(
|
||||||
self.prototype_win_ratios,
|
[
|
||||||
prototype_wr,
|
self.prototype_win_ratios,
|
||||||
])
|
prototype_wr,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx):
|
def shared_step(self, batch, batch_idx):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
@ -110,11 +115,9 @@ class SiameseGLVQ(GLVQ):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(
|
||||||
hparams,
|
self, hparams, backbone=torch.nn.Identity(), both_path_gradients=False, **kwargs
|
||||||
backbone=torch.nn.Identity(),
|
):
|
||||||
both_path_gradients=False,
|
|
||||||
**kwargs):
|
|
||||||
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
|
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
|
||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
@ -176,6 +179,7 @@ class GRLVQ(SiameseGLVQ):
|
|||||||
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_relevances: torch.Tensor
|
_relevances: torch.Tensor
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
@ -186,8 +190,7 @@ class GRLVQ(SiameseGLVQ):
|
|||||||
self.register_parameter("_relevances", Parameter(relevances))
|
self.register_parameter("_relevances", Parameter(relevances))
|
||||||
|
|
||||||
# Override the backbone
|
# Override the backbone
|
||||||
self.backbone = LambdaLayer(self._apply_relevances,
|
self.backbone = LambdaLayer(self._apply_relevances, name="relevance scaling")
|
||||||
name="relevance scaling")
|
|
||||||
|
|
||||||
def _apply_relevances(self, x):
|
def _apply_relevances(self, x):
|
||||||
return x @ torch.diag(self._relevances)
|
return x @ torch.diag(self._relevances)
|
||||||
@ -211,8 +214,9 @@ class SiameseGMLVQ(SiameseGLVQ):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Override the backbone
|
# Override the backbone
|
||||||
omega_initializer = kwargs.get("omega_initializer",
|
omega_initializer = kwargs.get(
|
||||||
EyeLinearTransformInitializer())
|
"omega_initializer", EyeLinearTransformInitializer()
|
||||||
|
)
|
||||||
self.backbone = LinearTransform(
|
self.backbone = LinearTransform(
|
||||||
self.hparams["input_dim"],
|
self.hparams["input_dim"],
|
||||||
self.hparams["latent_dim"],
|
self.hparams["latent_dim"],
|
||||||
@ -232,48 +236,46 @@ class SiameseGMLVQ(SiameseGLVQ):
|
|||||||
|
|
||||||
class GMLMLVQ(GLVQ):
|
class GMLMLVQ(GLVQ):
|
||||||
"""Generalized Multi-Layer Matrix Learning Vector Quantization.
|
"""Generalized Multi-Layer Matrix Learning Vector Quantization.
|
||||||
|
Masks are applied to the omega layers to achieve sparsity and constrain
|
||||||
|
learning to certain items of each omega.
|
||||||
|
|
||||||
Implemented as a regular GLVQ network that simply uses a different distance
|
Implemented as a regular GLVQ network that simply uses a different distance
|
||||||
function. This makes it easier to implement a localized variant.
|
function. This makes it easier to implement a localized variant.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
_omega_0: torch.Tensor
|
_omegas: list[torch.Tensor]
|
||||||
_omega_1: torch.Tensor
|
masks: list[torch.Tensor]
|
||||||
|
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
|
distance_fn = kwargs.pop("distance_fn", ML_omega_distance)
|
||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
omega_initializer = kwargs.get("omega_initializer")
|
|
||||||
masks = kwargs.get("masks")
|
masks = kwargs.get("masks")
|
||||||
omega_0 = LLTI(masks[0]).generate(1, 1)
|
for i, _mask in enumerate(masks):
|
||||||
omega_1 = LLTI(masks[1]).generate(1, 1)
|
self.register_buffer(f"_mask_{i}", _mask)
|
||||||
self.register_parameter("_omega_0", Parameter(omega_0))
|
self._masks = [self.__getattr__(f"_mask_{i}") for i,_ in enumerate(masks)]
|
||||||
self.register_parameter("_omega_1", Parameter(omega_1))
|
self._omegas = ParameterList([LLTI(mask).generate(1, 1) for mask in masks])
|
||||||
self.mask_0 = masks[0]
|
|
||||||
self.mask_1 = masks[1]
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrices(self):
|
def omega_matrices(self):
|
||||||
return [self._omega_0.detach().cpu(), self._omega_1.detach().cpu()]
|
return [_omega.detach().cpu() for _omega in self._omegas]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lambda_matrix(self):
|
def lambda_matrix(self):
|
||||||
|
# TODO update to respective lambda calculation rules.
|
||||||
omega = self._omega.detach() # (input_dim, latent_dim)
|
omega = self._omega.detach() # (input_dim, latent_dim)
|
||||||
lam = omega @ omega.T
|
lam = omega @ omega.T
|
||||||
return lam.detach().cpu()
|
return lam.detach().cpu()
|
||||||
|
|
||||||
def compute_distances(self, x):
|
def compute_distances(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
distances = self.distance_layer(x, protos, self._omega_0,
|
distances = self.distance_layer(x, protos, self._omegas, self._masks)
|
||||||
self._omega_1, self.mask_0,
|
|
||||||
self.mask_1)
|
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
return f"(omegas): (shapes: {[tuple(_omega.shape) for _omega in self._omegas]})"
|
||||||
|
|
||||||
|
|
||||||
class GMLVQ(GLVQ):
|
class GMLVQ(GLVQ):
|
||||||
@ -292,10 +294,12 @@ class GMLVQ(GLVQ):
|
|||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
omega_initializer = kwargs.get("omega_initializer",
|
omega_initializer = kwargs.get(
|
||||||
EyeLinearTransformInitializer())
|
"omega_initializer", EyeLinearTransformInitializer()
|
||||||
omega = omega_initializer.generate(self.hparams["input_dim"],
|
)
|
||||||
self.hparams["latent_dim"])
|
omega = omega_initializer.generate(
|
||||||
|
self.hparams["input_dim"], self.hparams["latent_dim"]
|
||||||
|
)
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
Loading…
Reference in New Issue
Block a user