Add Local-Matrix LVQ

Also remove the use of `self.distance_fn` in favor of `self.distance_layer`.
Jensun Ravichandran 2021-06-01 23:44:16 +02:00
parent 5ec2dd47cd
commit 757f4e980d
2 changed files with 96 additions and 54 deletions
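Note on the `distance_fn` → `distance_layer` change: the old attribute was a bare callable, while the new one wraps the same callable in a `LambdaLayer` submodule, so it participates in `repr()` and module traversal like any other layer. A simplified sketch of that wrapper idea (a stand-in, not the actual `prototorch.modules.LambdaLayer`):

```python
import torch

class LambdaLayer(torch.nn.Module):
    """Simplified sketch: make a plain callable usable as an nn.Module."""
    def __init__(self, fn, name=None):
        super().__init__()
        self.fn = fn
        self.name = name or fn.__name__

    def forward(self, *args, **kwargs):
        # Delegate to the wrapped callable, e.g. a distance function.
        return self.fn(*args, **kwargs)
```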

prototorch/models/__init__.py

@@ -1,10 +1,23 @@
"""`models` plugin for the `prototorch` package."""
from importlib.metadata import PackageNotFoundError, version
from .probabilistic import LikelihoodRatioLVQ, RSLVQ
from .cbc import CBC, ImageCBC
from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
ImageGMLVQ, SiameseGLVQ)
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import KNN, NeuralGas
from .vis import *
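With the updated exports, the new and relocated classes come straight from the package root (a quick sketch; assumes the plugin is installed so it is importable as `prototorch.models`):

```python
# New in this commit:
from prototorch.models import LGMLVQ, SiameseGMLVQ
# GMLVQ is now the plain-GLVQ matrix variant; CELVQ moved to .probabilistic
# but stays re-exported at the top level:
from prototorch.models import CELVQ, GMLVQ
```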

prototorch/models/glvq.py

@@ -4,12 +4,17 @@ import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
-from prototorch.functions.competitions import stratified_min, wtac
-from prototorch.functions.distances import (euclidean_distance, omega_distance,
-                                            sed)
+from prototorch.functions.competitions import wtac
+from prototorch.functions.distances import (
+    euclidean_distance,
+    lomega_distance,
+    omega_distance,
+    squared_euclidean_distance,
+)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 from prototorch.modules import LambdaLayer
+from torch.nn.parameter import Parameter

 from .abstract import AbstractPrototypeModel, PrototypeImageModel
@@ -17,17 +22,19 @@ from .abstract import AbstractPrototypeModel, PrototypeImageModel
 class GLVQ(AbstractPrototypeModel):
     """Generalized Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__()

         # Hyperparameters
-        self.save_hyperparameters(hparams)  # Default Values
+        self.save_hyperparameters(hparams)
+
+        # Defaults
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
         self.hparams.setdefault("lr", 0.01)

         distance_fn = kwargs.get("distance_fn", euclidean_distance)
-        tranfer_fn = get_activation(self.hparams.transfer_fn)
+        transfer_fn = get_activation(self.hparams.transfer_fn)

         # Layers
         self.proto_layer = LabeledComponents(
@@ -35,7 +42,7 @@ class GLVQ(AbstractPrototypeModel):
             initializer=self.prototype_initializer(**kwargs))

         self.distance_layer = LambdaLayer(distance_fn)
-        self.transfer_layer = LambdaLayer(tranfer_fn)
+        self.transfer_layer = LambdaLayer(transfer_fn)
         self.loss = LambdaLayer(glvq_loss)

         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
@@ -123,8 +130,12 @@ class GLVQ(AbstractPrototypeModel):
     # def predict_step(self, batch, batch_idx, dataloader_idx=None):
     #     pass

-    def increase_prototypes(self, initializer, distribution):
-        self.proto_layer.increase_components(initializer, distribution)
+    def add_prototypes(self, initializer, distribution):
+        self.proto_layer.add_components(initializer, distribution)
+        self.trainer.accelerator_backend.setup_optimizers(self.trainer)
+
+    def remove_prototypes(self, indices):
+        self.proto_layer.remove_components(indices)
         self.trainer.accelerator_backend.setup_optimizers(self.trainer)

     def __repr__(self):
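Side note on the hunk above: `increase_prototypes` becomes `add_prototypes` and gains a `remove_prototypes` counterpart; both re-run optimizer setup because the parameter set changes. A hypothetical call sequence (everything except the two method names is a placeholder; assumes the model is attached to a running Lightning `Trainer`):

```python
# Hypothetical: grow, then shrink, the prototype set mid-training.
# `initializer` stands in for whatever component initializer the model uses.
model.add_prototypes(initializer, distribution=[1, 1])  # one new prototype per class
model.remove_prototypes(indices=[0])                    # drop prototype 0
```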
@@ -145,10 +156,10 @@ class SiameseGLVQ(GLVQ):
                  backbone=torch.nn.Identity(),
                  both_path_gradients=False,
                  **kwargs):
-        super().__init__(hparams, **kwargs)
+        distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
         self.both_path_gradients = both_path_gradients
-        self.distance_fn = kwargs.get("distance_fn", sed)

     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
@@ -168,7 +179,7 @@ class SiameseGLVQ(GLVQ):
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances

     def predict_latent(self, x, map_protos=True):
@@ -183,39 +194,44 @@ class SiameseGLVQ(GLVQ):
         protos, plabels = self.proto_layer()
         if map_protos:
             protos = self.backbone(protos)
-        d = self.distance_fn(x, protos)
+        d = self.distance_layer(x, protos)
         y_pred = wtac(d, plabels)
         return y_pred


 class GRLVQ(SiameseGLVQ):
-    """Generalized Relevance Learning Vector Quantization."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.relevances = torch.nn.parameter.Parameter(
-            torch.ones(self.hparams.input_dim))
-
-        # Overwrite backbone
-        self.backbone = self._backbone
+    """Generalized Relevance Learning Vector Quantization.
+
+    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+        relevances = torch.ones(self.hparams.input_dim, device=self.device)
+        self.register_parameter("_relevances", Parameter(relevances))
+
+        # Override the backbone.
+        self.backbone = LambdaLayer(lambda x: x @ torch.diag(self._relevances),
+                                    name="relevances")

     @property
     def relevance_profile(self):
         return self._relevances.detach().cpu()

-    def _backbone(self, x):
-        """Namespace hook for the visualization callbacks to work."""
-        return x @ torch.diag(self.relevances)
-
     def _forward(self, x):
         protos, _ = self.proto_layer()
-        distances = omega_distance(x, protos, torch.diag(self.relevances))
+        distances = self.distance_layer(x, protos, torch.diag(self._relevances))
         return distances


-class GMLVQ(SiameseGLVQ):
-    """Generalized Matrix Learning Vector Quantization."""
+class SiameseGMLVQ(SiameseGLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a Siamese network with a linear transformation backbone.
+
+    """
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
+
+        # Override the backbone.
         self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                         self.hparams.latent_dim,
                                         bias=False)
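For intuition: GRLVQ's relevance vector is the diagonal special case of the omega matrix used by the GMLVQ variants, since `x @ torch.diag(r)` merely rescales each feature. A self-contained sketch of that equivalence (plain PyTorch, not the library's distance functions):

```python
import torch

x = torch.randn(5, 4)   # batch of inputs
w = torch.randn(3, 4)   # prototypes
r = torch.rand(4)       # feature relevances (GRLVQ)

diff = x.unsqueeze(1) - w                      # (5, 3, 4) pairwise differences
d_grlvq = ((diff * r)**2).sum(-1)              # relevance-weighted squared distances
d_gmlvq = ((diff @ torch.diag(r))**2).sum(-1)  # omega distance with omega = diag(r)
assert torch.allclose(d_grlvq, d_gmlvq)
```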
@@ -230,16 +246,6 @@ class GMLVQ(SiameseGLVQ):
         lam = omega.T @ omega
         return lam.detach().cpu()

-    def show_lambda(self):
-        import matplotlib.pyplot as plt
-        title = "Lambda matrix"
-        plt.figure(title)
-        plt.title(title)
-        plt.imshow(self.lambda_matrix, cmap="gray")
-        plt.axis("off")
-        plt.colorbar()
-        plt.show(block=True)
-
     def _forward(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
@@ -247,7 +253,7 @@ class GMLVQ(SiameseGLVQ):
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances
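With `show_lambda()` removed, plotting becomes the caller's job via the surviving `lambda_matrix` property. A minimal external replacement (assumes matplotlib and a model exposing `lambda_matrix`):

```python
import matplotlib.pyplot as plt

def show_lambda(model):
    """External stand-in for the removed method; `model` is assumed to be a
    GMLVQ-family instance exposing the `lambda_matrix` property."""
    plt.imshow(model.lambda_matrix, cmap="gray")
    plt.colorbar()
    plt.show(block=True)
```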
@@ -263,24 +269,47 @@ class LVQMLN(SiameseGLVQ):
     def _forward(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances


-class CELVQ(GLVQ):
-    """Cross-Entropy Learning Vector Quantization."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.loss = torch.nn.CrossEntropyLoss()
-
-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
-        x, y = batch
-        out = self._forward(x)  # [None, num_protos]
-        plabels = self.proto_layer.component_labels
-        probs = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
-        batch_loss = self.loss(probs, y.long())
-        loss = batch_loss.sum(dim=0)
-        return out, loss
+class GMLVQ(GLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a regular GLVQ network that simply uses a different distance
+    function.
+
+    """
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+        omega = torch.randn(self.hparams.input_dim,
+                            self.hparams.latent_dim,
+                            device=self.device)
+        self.register_parameter("_omega", Parameter(omega))
+
+    def _forward(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos, self._omega)
+        return distances
+
+    def extra_repr(self):
+        return f"(omega): (shape: {tuple(self._omega.shape)})"
+
+
+class LGMLVQ(GMLVQ):
+    """Localized and Generalized Matrix Learning Vector Quantization."""
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", lomega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Re-register `_omega` to override the one from the super class.
+        omega = torch.randn(
+            self.num_prototypes,
+            self.hparams.input_dim,
+            self.hparams.latent_dim,
+            device=self.device,
+        )
+        self.register_parameter("_omega", Parameter(omega))


 class GLVQ1(GLVQ):
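Finally, the headline feature: `LGMLVQ` keeps one omega matrix per prototype and dispatches to `lomega_distance`. A rough, self-contained sketch of what such a local-omega distance computes (plain PyTorch, not the `prototorch.functions.distances` implementation):

```python
import torch

def local_omega_distance(x, protos, omegas):
    """Per-prototype projected squared distances.

    x:      (n, d) inputs
    protos: (p, d) prototypes
    omegas: (p, d, k) one projection matrix per prototype
    """
    diff = x.unsqueeze(1) - protos                     # (n, p, d) differences
    proj = torch.einsum("npd,pdk->npk", diff, omegas)  # project per prototype
    return proj.pow(2).sum(-1)                         # squared norms, (n, p)
```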