Add Local-Matrix LVQ

Also remove the use of `self.distance_fn` in favor of `self.distance_layer`.
Jensun Ravichandran 2021-06-01 23:44:16 +02:00
parent 5ec2dd47cd
commit 757f4e980d
2 changed files with 96 additions and 54 deletions
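For orientation, a minimal sketch of how the refactored API reads after this commit. The class names, the `distance_fn` keyword, and `distance_layer` are taken from the diff below; the hyperparameter values are illustrative, the prototype-initializer kwargs are omitted, and `x` stands for a `[batch, input_dim]` tensor:

    hparams = dict(distribution=[1, 1], input_dim=2, latent_dim=2)  # illustrative values
    model = LGMLVQ(hparams)  # learns one omega matrix per prototype
    protos, plabels = model.proto_layer()
    # A distance_fn passed via kwargs is wrapped in a LambdaLayer and exposed
    # as `model.distance_layer`; the old `model.distance_fn` attribute is gone.
    distances = model.distance_layer(x, protos, model._omega)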

prototorch/models/__init__.py

@@ -1,10 +1,23 @@
 """`models` plugin for the `prototorch` package."""
 
 from importlib.metadata import PackageNotFoundError, version
 
 from .cbc import CBC, ImageCBC
-from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
-                   ImageGMLVQ, SiameseGLVQ)
+from .glvq import (
+    GLVQ,
+    GLVQ1,
+    GLVQ21,
+    GMLVQ,
+    GRLVQ,
+    LGMLVQ,
+    LVQMLN,
+    ImageGLVQ,
+    ImageGMLVQ,
+    SiameseGLVQ,
+    SiameseGMLVQ,
+)
 from .lvq import LVQ1, LVQ21, MedianLVQ
-from .probabilistic import LikelihoodRatioLVQ, RSLVQ
+from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
 from .unsupervised import KNN, NeuralGas
 from .vis import *

prototorch/models/glvq.py

@@ -4,12 +4,17 @@ import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
-from prototorch.functions.competitions import stratified_min, wtac
-from prototorch.functions.distances import (euclidean_distance, omega_distance,
-                                            sed)
+from prototorch.functions.competitions import wtac
+from prototorch.functions.distances import (
+    euclidean_distance,
+    lomega_distance,
+    omega_distance,
+    squared_euclidean_distance,
+)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 from prototorch.modules import LambdaLayer
+from torch.nn.parameter import Parameter
 
 from .abstract import AbstractPrototypeModel, PrototypeImageModel
@@ -17,17 +22,19 @@ from .abstract import AbstractPrototypeModel, PrototypeImageModel
 
 class GLVQ(AbstractPrototypeModel):
     """Generalized Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__()
 
         # Hyperparameters
-        self.save_hyperparameters(hparams)  # Default Values
+        self.save_hyperparameters(hparams)
+
+        # Defaults
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
         self.hparams.setdefault("lr", 0.01)
 
         distance_fn = kwargs.get("distance_fn", euclidean_distance)
-        tranfer_fn = get_activation(self.hparams.transfer_fn)
+        transfer_fn = get_activation(self.hparams.transfer_fn)
 
         # Layers
         self.proto_layer = LabeledComponents(
@@ -35,7 +42,7 @@ class GLVQ(AbstractPrototypeModel):
             initializer=self.prototype_initializer(**kwargs))
 
         self.distance_layer = LambdaLayer(distance_fn)
-        self.transfer_layer = LambdaLayer(tranfer_fn)
+        self.transfer_layer = LambdaLayer(transfer_fn)
         self.loss = LambdaLayer(glvq_loss)
 
         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
@@ -123,8 +130,12 @@ class GLVQ(AbstractPrototypeModel):
     # def predict_step(self, batch, batch_idx, dataloader_idx=None):
     #     pass
 
-    def increase_prototypes(self, initializer, distribution):
-        self.proto_layer.increase_components(initializer, distribution)
+    def add_prototypes(self, initializer, distribution):
+        self.proto_layer.add_components(initializer, distribution)
+        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
+
+    def remove_prototypes(self, indices):
+        self.proto_layer.remove_components(indices)
         self.trainer.accelerator_backend.setup_optimizers(self.trainer)
 
     def __repr__(self):
@@ -145,10 +156,10 @@ class SiameseGLVQ(GLVQ):
                  backbone=torch.nn.Identity(),
                  both_path_gradients=False,
                  **kwargs):
-        super().__init__(hparams, **kwargs)
+        distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
         self.both_path_gradients = both_path_gradients
-        self.distance_fn = kwargs.get("distance_fn", sed)
 
     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
@@ -168,7 +179,7 @@ class SiameseGLVQ(GLVQ):
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances
 
     def predict_latent(self, x, map_protos=True):
@@ -183,39 +194,44 @@ class SiameseGLVQ(GLVQ):
         protos, plabels = self.proto_layer()
         if map_protos:
             protos = self.backbone(protos)
-        d = self.distance_fn(x, protos)
+        d = self.distance_layer(x, protos)
         y_pred = wtac(d, plabels)
         return y_pred
 
 
 class GRLVQ(SiameseGLVQ):
-    """Generalized Relevance Learning Vector Quantization."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.relevances = torch.nn.parameter.Parameter(
-            torch.ones(self.hparams.input_dim))
-        # Overwrite backbone
-        self.backbone = self._backbone
+    """Generalized Relevance Learning Vector Quantization.
+
+    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+        relevances = torch.ones(self.hparams.input_dim, device=self.device)
+        self.register_parameter("_relevances", Parameter(relevances))
+
+        # Override the backbone.
+        self.backbone = LambdaLayer(lambda x: x @ torch.diag(self.relevances),
+                                    name="relevances")
 
     @property
     def relevance_profile(self):
         return self.relevances.detach().cpu()
 
-    def _backbone(self, x):
-        """Namespace hook for the visualization callbacks to work."""
-        return x @ torch.diag(self.relevances)
-
     def _forward(self, x):
         protos, _ = self.proto_layer()
-        distances = omega_distance(x, protos, torch.diag(self.relevances))
+        distances = self.distance_layer(x, protos, torch.diag(self.relevances))
         return distances
 
 
-class GMLVQ(SiameseGLVQ):
-    """Generalized Matrix Learning Vector Quantization."""
+class SiameseGMLVQ(SiameseGLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a Siamese network with a linear transformation backbone.
+
+    """
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
+
+        # Override the backbone.
         self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                         self.hparams.latent_dim,
                                         bias=False)
@@ -230,16 +246,6 @@ class GMLVQ(SiameseGLVQ):
         lam = omega.T @ omega
         return lam.detach().cpu()
 
-    def show_lambda(self):
-        import matplotlib.pyplot as plt
-        title = "Lambda matrix"
-        plt.figure(title)
-        plt.title(title)
-        plt.imshow(self.lambda_matrix, cmap="gray")
-        plt.axis("off")
-        plt.colorbar()
-        plt.show(block=True)
-
     def _forward(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
@@ -247,7 +253,7 @@ class GMLVQ(SiameseGLVQ):
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances
 
@@ -263,24 +269,47 @@ class LVQMLN(SiameseGLVQ):
     def _forward(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances
 
 
-class CELVQ(GLVQ):
-    """Cross-Entropy Learning Vector Quantization."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.loss = torch.nn.CrossEntropyLoss()
-
-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
-        x, y = batch
-        out = self._forward(x)  # [None, num_protos]
-        plabels = self.proto_layer.component_labels
-        probs = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
-        batch_loss = self.loss(out, y.long())
-        loss = batch_loss.sum(dim=0)
-        return out, loss
+class GMLVQ(GLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a regular GLVQ network that simply uses a different distance
+    function.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+        omega = torch.randn(self.hparams.input_dim,
+                            self.hparams.latent_dim,
+                            device=self.device)
+        self.register_parameter("_omega", Parameter(omega))
+
+    def _forward(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos, self._omega)
+        return distances
+
+    def extra_repr(self):
+        return f"(omega): (shape: {tuple(self._omega.shape)})"
+
+
+class LGMLVQ(GMLVQ):
+    """Localized and Generalized Matrix Learning Vector Quantization."""
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", lomega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Re-register `_omega` to override the one from the super class.
+        omega = torch.randn(
+            self.num_prototypes,
+            self.hparams.input_dim,
+            self.hparams.latent_dim,
+            device=self.device,
+        )
+        self.register_parameter("_omega", Parameter(omega))
 
 
 class GLVQ1(GLVQ):
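
`lomega_distance` itself lives in `prototorch.functions.distances` and is not part of this diff. A rough sketch of the semantics implied by the omega shape registered above, with one `input_dim` x `latent_dim` projection per prototype and squared distances measured in each prototype's own latent space; this is an illustration under those assumptions, not prototorch's implementation:

    import torch

    def local_omega_distance_sketch(x, protos, omegas):
        # x:      [B, D] batch of inputs
        # protos: [K, D] prototypes
        # omegas: [K, D, L] one projection matrix per prototype
        diff = x.unsqueeze(1) - protos.unsqueeze(0)        # [B, K, D] pairwise differences
        proj = torch.einsum("bkd,kdl->bkl", diff, omegas)  # project each difference through its omega_k
        return (proj**2).sum(-1)                           # squared distances, shape [B, K]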