Add Local-Matrix LVQ

Also remove the use of `self.distance_fn` in favor of `self.distance_layer`.
Jensun Ravichandran
2021-06-01 23:44:16 +02:00
parent 5ec2dd47cd
commit 757f4e980d
2 changed files with 96 additions and 54 deletions
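
Below is a minimal sketch (not part of this commit) of the shape convention behind the matrix classes touched here: GMLVQ registers a single `_omega` of shape `(input_dim, latent_dim)` shared by all prototypes, while the new local-matrix LGMLVQ registers one omega per prototype, i.e. shape `(num_prototypes, input_dim, latent_dim)`, and scores with `lomega_distance`. The dimension values are placeholders, and the three-argument distance calls simply mirror the `_forward` methods in the diff below.

```python
import torch
from prototorch.functions.distances import lomega_distance, omega_distance

# Placeholder sizes; in the models these come from the hparams and the proto_layer.
input_dim, latent_dim, num_prototypes = 4, 2, 6

x = torch.randn(32, input_dim)                   # a batch of inputs
protos = torch.randn(num_prototypes, input_dim)  # prototype vectors

# GMLVQ: one omega shared by every prototype (mirrors GMLVQ._forward).
omega = torch.randn(input_dim, latent_dim)
d_global = omega_distance(x, protos, omega)

# LGMLVQ: one omega per prototype, hence the extra leading dimension
# (mirrors LGMLVQ._forward).
omegas = torch.randn(num_prototypes, input_dim, latent_dim)
d_local = lomega_distance(x, protos, omegas)
```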


@@ -4,12 +4,17 @@ import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
-from prototorch.functions.competitions import stratified_min, wtac
-from prototorch.functions.distances import (euclidean_distance, omega_distance,
-                                            sed)
+from prototorch.functions.competitions import wtac
+from prototorch.functions.distances import (
+    euclidean_distance,
+    lomega_distance,
+    omega_distance,
+    squared_euclidean_distance,
+)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 from prototorch.modules import LambdaLayer
+from torch.nn.parameter import Parameter
 
 from .abstract import AbstractPrototypeModel, PrototypeImageModel
@@ -17,17 +22,19 @@ from .abstract import AbstractPrototypeModel, PrototypeImageModel
 class GLVQ(AbstractPrototypeModel):
     """Generalized Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__()
 
         # Hyperparameters
-        self.save_hyperparameters(hparams)  # Default Values
+        self.save_hyperparameters(hparams)
+
+        # Defaults
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
         self.hparams.setdefault("lr", 0.01)
 
         distance_fn = kwargs.get("distance_fn", euclidean_distance)
-        tranfer_fn = get_activation(self.hparams.transfer_fn)
+        transfer_fn = get_activation(self.hparams.transfer_fn)
 
         # Layers
         self.proto_layer = LabeledComponents(
@@ -35,7 +42,7 @@ class GLVQ(AbstractPrototypeModel):
             initializer=self.prototype_initializer(**kwargs))
 
         self.distance_layer = LambdaLayer(distance_fn)
-        self.transfer_layer = LambdaLayer(tranfer_fn)
+        self.transfer_layer = LambdaLayer(transfer_fn)
         self.loss = LambdaLayer(glvq_loss)
 
         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
@@ -123,8 +130,12 @@ class GLVQ(AbstractPrototypeModel):
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass
def increase_prototypes(self, initializer, distribution):
self.proto_layer.increase_components(initializer, distribution)
def add_prototypes(self, initializer, distribution):
self.proto_layer.add_components(initializer, distribution)
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
def remove_prototypes(self, indices):
self.proto_layer.remove_components(indices)
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
def __repr__(self):
@@ -145,10 +156,10 @@ class SiameseGLVQ(GLVQ):
                  backbone=torch.nn.Identity(),
                  both_path_gradients=False,
                  **kwargs):
-        super().__init__(hparams, **kwargs)
+        distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
         self.backbone = backbone
         self.both_path_gradients = both_path_gradients
-        self.distance_fn = kwargs.get("distance_fn", sed)
 
     def configure_optimizers(self):
         proto_opt = self.optimizer(self.proto_layer.parameters(),
@@ -168,7 +179,7 @@ class SiameseGLVQ(GLVQ):
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances
 
     def predict_latent(self, x, map_protos=True):
@@ -183,39 +194,44 @@ class SiameseGLVQ(GLVQ):
         protos, plabels = self.proto_layer()
         if map_protos:
             protos = self.backbone(protos)
-        d = self.distance_fn(x, protos)
+        d = self.distance_layer(x, protos)
         y_pred = wtac(d, plabels)
         return y_pred
 
 
 class GRLVQ(SiameseGLVQ):
-    """Generalized Relevance Learning Vector Quantization."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.relevances = torch.nn.parameter.Parameter(
-            torch.ones(self.hparams.input_dim))
-
-        # Overwrite backbone
-        self.backbone = self._backbone
+    """Generalized Relevance Learning Vector Quantization.
+
+    TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
+    """
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+        relevances = torch.ones(self.hparams.input_dim, device=self.device)
+        self.register_parameter("_relevances", Parameter(relevances))
+
+        # Override the backbone.
+        self.backbone = LambdaLayer(lambda x: x @ torch.diag(self.relevances),
+                                    name="relevances")
 
     @property
     def relevance_profile(self):
         return self.relevances.detach().cpu()
 
-    def _backbone(self, x):
-        """Namespace hook for the visualization callbacks to work."""
-        return x @ torch.diag(self.relevances)
-
     def _forward(self, x):
         protos, _ = self.proto_layer()
-        distances = omega_distance(x, protos, torch.diag(self.relevances))
+        distances = self.distance_layer(x, protos, torch.diag(self.relevances))
         return distances
 
 
-class GMLVQ(SiameseGLVQ):
-    """Generalized Matrix Learning Vector Quantization."""
+class SiameseGMLVQ(SiameseGLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a Siamese network with a linear transformation backbone.
+    """
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
+        # Override the backbone.
         self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                         self.hparams.latent_dim,
                                         bias=False)
@@ -230,16 +246,6 @@ class GMLVQ(SiameseGLVQ):
         lam = omega.T @ omega
         return lam.detach().cpu()
 
-    def show_lambda(self):
-        import matplotlib.pyplot as plt
-        title = "Lambda matrix"
-        plt.figure(title)
-        plt.title(title)
-        plt.imshow(self.lambda_matrix, cmap="gray")
-        plt.axis("off")
-        plt.colorbar()
-        plt.show(block=True)
-
     def _forward(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
@@ -247,7 +253,7 @@ class GMLVQ(SiameseGLVQ):
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances
@@ -263,24 +269,47 @@ class LVQMLN(SiameseGLVQ):
     def _forward(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
-        distances = self.distance_fn(latent_x, latent_protos)
+        distances = self.distance_layer(latent_x, latent_protos)
         return distances
 
 
-class CELVQ(GLVQ):
-    """Cross-Entropy Learning Vector Quantization."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.loss = torch.nn.CrossEntropyLoss()
-
-    def shared_step(self, batch, batch_idx, optimizer_idx=None):
-        x, y = batch
-        out = self._forward(x)  # [None, num_protos]
-        plabels = self.proto_layer.component_labels
-        probs = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
-        batch_loss = self.loss(out, y.long())
-        loss = batch_loss.sum(dim=0)
-        return out, loss
+class GMLVQ(GLVQ):
+    """Generalized Matrix Learning Vector Quantization.
+
+    Implemented as a regular GLVQ network that simply uses a different distance
+    function.
+    """
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", omega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+        omega = torch.randn(self.hparams.input_dim,
+                            self.hparams.latent_dim,
+                            device=self.device)
+        self.register_parameter("_omega", Parameter(omega))
+
+    def _forward(self, x):
+        protos, _ = self.proto_layer()
+        distances = self.distance_layer(x, protos, self._omega)
+        return distances
+
+    def extra_repr(self):
+        return f"(omega): (shape: {tuple(self._omega.shape)})"
+
+
+class LGMLVQ(GMLVQ):
+    """Localized and Generalized Matrix Learning Vector Quantization."""
+    def __init__(self, hparams, **kwargs):
+        distance_fn = kwargs.pop("distance_fn", lomega_distance)
+        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
+
+        # Re-register `_omega` to override the one from the super class.
+        omega = torch.randn(
+            self.num_prototypes,
+            self.hparams.input_dim,
+            self.hparams.latent_dim,
+            device=self.device,
+        )
+        self.register_parameter("_omega", Parameter(omega))
 
 
 class GLVQ1(GLVQ):