2021-05-25 18:37:34 +00:00
|
|
|
"""Models based on the GLVQ framework."""
|
|
|
|
|
2021-04-21 12:51:34 +00:00
|
|
|
import torch
|
2021-04-21 17:16:57 +00:00
|
|
|
import torchmetrics
|
2021-04-29 15:05:41 +00:00
|
|
|
from prototorch.components import LabeledComponents
|
2021-05-04 18:56:16 +00:00
|
|
|
from prototorch.functions.activations import get_activation
|
2021-05-27 15:40:16 +00:00
|
|
|
from prototorch.functions.competitions import stratified_min, wtac
|
2021-05-06 16:42:06 +00:00
|
|
|
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
2021-05-17 15:00:23 +00:00
|
|
|
sed)
|
2021-05-12 14:36:22 +00:00
|
|
|
from prototorch.functions.helper import get_flat
|
2021-05-27 15:40:16 +00:00
|
|
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
2021-05-12 14:36:22 +00:00
|
|
|
|
|
|
|
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
2021-04-29 15:05:41 +00:00
|
|
|
|
2021-04-21 12:51:34 +00:00
|
|
|
|
2021-04-29 15:05:41 +00:00
|
|
|
class GLVQ(AbstractPrototypeModel):
    """Generalized Learning Vector Quantization.

    Args:
        hparams: Hyperparameters stored via ``save_hyperparameters``. Must
            provide ``distribution`` (forwarded to ``LabeledComponents``).
            ``transfer_fn`` (default ``"identity"``), ``transfer_beta``
            (default ``10.0``) and ``lr`` (default ``0.01``) are optional.
        **kwargs: Optional ``distance_fn`` (default ``euclidean_distance``),
            ``optimizer`` (default ``torch.optim.Adam``) and
            ``prototype_initializer`` (default ``None``).
    """
    def __init__(self, hparams, **kwargs):
        super().__init__()

        self.save_hyperparameters(hparams)

        self.distance_fn = kwargs.get("distance_fn", euclidean_distance)
        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)

        # Default Values
        self.hparams.setdefault("transfer_fn", "identity")
        self.hparams.setdefault("transfer_beta", 10.0)
        self.hparams.setdefault("lr", 0.01)

        self.proto_layer = LabeledComponents(
            distribution=self.hparams.distribution,
            initializer=self.prototype_initializer(**kwargs))

        self.transfer_fn = get_activation(self.hparams.transfer_fn)

        # torchmetrics metrics accumulate state across calls, so a single
        # shared instance would mix train/val/test batches. Keep one instance
        # per logging tag. `acc_metric` is kept for backward compatibility; it
        # serves the training loop and any tag not listed here.
        self.acc_metric = torchmetrics.Accuracy()
        self._acc_metrics = torch.nn.ModuleDict({
            "train_acc": self.acc_metric,
            "val_acc": torchmetrics.Accuracy(),
            "test_acc": torchmetrics.Accuracy(),
        })

        self.loss = glvq_loss

    def prototype_initializer(self, **kwargs):
        """Return the ``prototype_initializer`` keyword argument, or ``None``."""
        return kwargs.get("prototype_initializer", None)

    @property
    def prototype_labels(self):
        """Prototype labels, detached and moved to the CPU."""
        return self.proto_layer.component_labels.detach().cpu()

    @property
    def num_classes(self):
        """Number of classes: the length of the prototype distribution."""
        return len(self.proto_layer.distribution)

    def _forward(self, x):
        """Return ``distance_fn(x, prototypes)``."""
        protos, _ = self.proto_layer()
        distances = self.distance_fn(x, protos)
        return distances

    def forward(self, x):
        """Return one-hot encoded predictions for ``x``."""
        distances = self._forward(x)
        y_pred = self.predict_from_distances(distances)
        # int64 is the canonical index dtype; some PyTorch releases reject
        # int32 index tensors.
        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
        return y_pred

    def predict_from_distances(self, distances):
        """Predict labels from precomputed ``distances`` via ``wtac``."""
        with torch.no_grad():
            plabels = self.proto_layer.component_labels
            y_pred = wtac(distances, plabels)
        return y_pred

    def predict(self, x):
        """Predict labels for ``x`` without tracking gradients."""
        with torch.no_grad():
            distances = self._forward(x)
            y_pred = self.predict_from_distances(distances)
        return y_pred

    def log_acc(self, distances, targets, tag):
        """Update and log the accuracy metric associated with ``tag``."""
        preds = self.predict_from_distances(distances)
        if tag in self._acc_metrics:
            metric = self._acc_metrics[tag]
        else:
            metric = self.acc_metric
        # `.int()` because FloatTensors are assumed to be class probabilities
        metric(preds.int(), targets.int())
        self.log(tag,
                 metric,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

    def shared_step(self, batch, batch_idx, optimizer_idx=None):
        """Return the distances and the summed, transferred loss of a batch."""
        x, y = batch
        out = self._forward(x)
        plabels = self.proto_layer.component_labels
        mu = self.loss(out, y, prototype_labels=plabels)
        batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
        loss = batch_loss.sum(dim=0)
        return out, loss

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
        self.log("train_loss", train_loss)
        self.log_acc(out, batch[-1], tag="train_acc")
        return train_loss

    def validation_step(self, batch, batch_idx):
        # `model.eval()` and `torch.no_grad()` handled by pl
        out, val_loss = self.shared_step(batch, batch_idx)
        self.log("val_loss", val_loss)
        self.log_acc(out, batch[-1], tag="val_acc")
        return val_loss

    def test_step(self, batch, batch_idx):
        # `model.eval()` and `torch.no_grad()` handled by pl
        out, test_loss = self.shared_step(batch, batch_idx)
        self.log_acc(out, batch[-1], tag="test_acc")
        return test_loss

    def test_epoch_end(self, outputs):
        """Log the sum of the per-batch test losses as ``test_loss``."""
        test_loss = sum((batch_loss.item() for batch_loss in outputs), 0.0)
        self.log("test_loss", test_loss)

    def __repr__(self):
        return super().__repr__()
|
|
|
|
|
2021-04-21 12:51:34 +00:00
|
|
|
|
2021-05-21 11:33:57 +00:00
|
|
|
class SiameseGLVQ(GLVQ):
    """GLVQ in a Siamese setting.

    GLVQ model that applies an arbitrary transformation on the inputs and the
    prototypes before computing the distances between them. Unless
    ``both_path_gradients`` is set, the weights in the transformation pipeline
    are only learned from the inputs.

    Args:
        hparams: Hyperparameters; ``configure_optimizers`` reads ``proto_lr``
            and, when the backbone has parameters, ``bb_lr``.
        backbone: Module applied to inputs and prototypes. Defaults to a new
            ``torch.nn.Identity`` for every model instance.
        both_path_gradients: If ``True``, the prototype path also contributes
            gradients to the backbone weights.
        **kwargs: Forwarded to ``GLVQ``; ``distance_fn`` defaults to ``sed``.
    """
    def __init__(self,
                 hparams,
                 backbone=None,
                 both_path_gradients=False,
                 **kwargs):
        super().__init__(hparams, **kwargs)
        # Build a fresh Identity per instance instead of sharing one module
        # object through the default argument.
        self.backbone = torch.nn.Identity() if backbone is None else backbone
        self.both_path_gradients = both_path_gradients
        self.distance_fn = kwargs.get("distance_fn", sed)

    def configure_optimizers(self):
        proto_opt = self.optimizer(self.proto_layer.parameters(),
                                   lr=self.hparams.proto_lr)
        bb_params = list(self.backbone.parameters())
        if bb_params:
            # Only add a backbone optimizer if the backbone has parameters;
            # constructing an optimizer with an empty parameter list fails.
            bb_opt = self.optimizer(bb_params, lr=self.hparams.bb_lr)
            return proto_opt, bb_opt
        return proto_opt

    def _embed_prototypes(self, protos):
        """Map ``protos`` through the backbone.

        Unless ``both_path_gradients`` is set, the backbone parameters are
        temporarily excluded from autograd, so the prototype path does not
        contribute gradients to the backbone weights. Each parameter's own
        ``requires_grad`` flag is restored afterwards, even on error, so
        parameters frozen by the user stay frozen.
        """
        if self.both_path_gradients:
            return self.backbone(protos)
        params = list(self.backbone.parameters())
        saved_flags = [p.requires_grad for p in params]
        for p in params:
            p.requires_grad_(False)
        try:
            return self.backbone(protos)
        finally:
            for p, flag in zip(params, saved_flags):
                p.requires_grad_(flag)

    def _forward(self, x):
        """Return distances between the embedded inputs and prototypes."""
        protos, _ = self.proto_layer()
        latent_x = self.backbone(x)
        latent_protos = self._embed_prototypes(protos)
        distances = self.distance_fn(latent_x, latent_protos)
        return distances

    def predict_latent(self, x, map_protos=True):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone (when ``map_protos`` is true). The model runs in eval mode
        for the prediction; its previous training mode is restored
        afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                protos, plabels = self.proto_layer()
                if map_protos:
                    protos = self.backbone(protos)
                d = self.distance_fn(x, protos)
                y_pred = wtac(d, plabels)
        finally:
            self.train(was_training)
        return y_pred
|
|
|
|
|
2021-04-29 21:37:22 +00:00
|
|
|
|
2021-05-20 15:36:00 +00:00
|
|
|
class _RelevanceScaling(torch.nn.Module):
    """Backbone module that scales inputs by a relevance vector.

    ``relevances`` is the Parameter owned by the GRLVQ model. Registering it
    here as well makes it visible through ``backbone.parameters()``.
    """
    def __init__(self, relevances):
        super().__init__()
        self.relevances = relevances

    def forward(self, x):
        return x @ torch.diag(self.relevances)


class GRLVQ(SiameseGLVQ):
    """Generalized Relevance Learning Vector Quantization.

    Requires ``hparams.input_dim``, the length of the relevance vector.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.relevances = torch.nn.parameter.Parameter(
            torch.ones(self.hparams.input_dim))

        # Overwrite backbone. `self.backbone` is already a registered child
        # module, and torch.nn.Module refuses to replace it with a plain
        # callable such as the bound method `self._backbone` (TypeError).
        # Wrap the relevances in a Module instead. This also exposes them via
        # `backbone.parameters()`, so `configure_optimizers` gives them an
        # optimizer (using `hparams.bb_lr`).
        self.backbone = _RelevanceScaling(self.relevances)

    @property
    def relevance_profile(self):
        """Relevance vector, detached and moved to the CPU."""
        return self.relevances.detach().cpu()

    def _backbone(self, x):
        """Namespace hook for the visualization callbacks to work."""
        return x @ torch.diag(self.relevances)

    def _forward(self, x):
        """Return relevance-weighted distances via ``omega_distance``."""
        protos, _ = self.proto_layer()
        distances = omega_distance(x, protos, torch.diag(self.relevances))
        return distances
|
2021-05-06 16:42:06 +00:00
|
|
|
|
|
|
|
|
2021-05-20 15:36:00 +00:00
|
|
|
class GMLVQ(SiameseGLVQ):
    """Generalized Matrix Learning Vector Quantization.

    The backbone is a bias-free ``torch.nn.Linear(input_dim, latent_dim)``,
    whose weight acts as the Omega matrix. Inputs and prototypes are flattened
    (``get_flat``) and mapped through it before the distances are computed.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                        self.hparams.latent_dim,
                                        bias=False)

    @property
    def omega_matrix(self):
        """Omega (the backbone weight), detached and moved to the CPU."""
        return self.backbone.weight.detach().cpu()

    @property
    def lambda_matrix(self):
        """``Omega.T @ Omega``, detached and moved to the CPU."""
        omega = self.backbone.weight.detach()  # (latent_dim, input_dim)
        lam = omega.T @ omega
        return lam.cpu()

    def show_lambda(self):
        """Display the Lambda matrix in a blocking matplotlib window."""
        import matplotlib.pyplot as plt
        title = "Lambda matrix"
        plt.figure(title)
        plt.title(title)
        plt.imshow(self.lambda_matrix, cmap="gray")
        plt.axis("off")
        plt.colorbar()
        plt.show(block=True)

    def _forward(self, x):
        """Return distances between the projected inputs and prototypes."""
        protos, _ = self.proto_layer()
        x, protos = get_flat(x, protos)
        latent_x = self.backbone(x)
        if self.both_path_gradients:
            latent_protos = self.backbone(protos)
        else:
            # Keep prototype-path gradients out of Omega. Restore each
            # parameter's own flag afterwards (instead of forcing True) so a
            # user-frozen Omega stays frozen.
            params = list(self.backbone.parameters())
            saved_flags = [p.requires_grad for p in params]
            for p in params:
                p.requires_grad_(False)
            try:
                latent_protos = self.backbone(protos)
            finally:
                for p, flag in zip(params, saved_flags):
                    p.requires_grad_(flag)
        distances = self.distance_fn(latent_x, latent_protos)
        return distances
|
2021-04-29 21:37:22 +00:00
|
|
|
|
2021-05-04 13:11:16 +00:00
|
|
|
|
2021-05-20 15:36:00 +00:00
|
|
|
class LVQMLN(SiameseGLVQ):
    """Learning Vector Quantization Multi-Layer Network.

    Unlike the Siamese setting, only the inputs are sent through the
    backbone; the prototypes are used as they are. The prototypes therefore
    live in the embedding space rather than in the input space.
    """
    def _forward(self, x):
        prototypes, _ = self.proto_layer()
        embedded_x = self.backbone(x)
        return self.distance_fn(embedded_x, prototypes)
|
2021-04-29 21:37:22 +00:00
|
|
|
|
2021-05-17 15:00:23 +00:00
|
|
|
|
2021-05-27 15:40:16 +00:00
|
|
|
class CELVQ(GLVQ):
    """Cross-Entropy Learning Vector Quantization.

    Each class is scored by its closest prototype (``stratified_min``). The
    negated scores are used as logits for ``torch.nn.CrossEntropyLoss``.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.loss = torch.nn.CrossEntropyLoss()

    def shared_step(self, batch, batch_idx, optimizer_idx=None):
        """Return the prototype distances and the cross-entropy loss."""
        x, y = batch
        out = self._forward(x)  # [None, num_protos]
        plabels = self.proto_layer.component_labels
        # Negate so that a closer class receives a larger logit. The loss
        # must be computed on these per-class logits, not on the raw
        # per-prototype distances.
        logits = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
        batch_loss = self.loss(logits, y.long())
        loss = batch_loss.sum()
        # The distances are returned so `log_acc` can derive predictions.
        return out, loss
|
|
|
|
|
|
|
|
|
2021-05-18 17:49:16 +00:00
|
|
|
class GLVQ1(GLVQ):
    """Generalized Learning Vector Quantization 1.

    GLVQ trained with ``lvq1_loss``. The optimizer defaults to SGD, but an
    explicit ``optimizer`` keyword argument is honoured.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.loss = lvq1_loss
        # Default to SGD without silently discarding a caller-supplied
        # optimizer.
        self.optimizer = kwargs.get("optimizer", torch.optim.SGD)
|
|
|
|
|
|
|
|
|
2021-05-18 17:49:16 +00:00
|
|
|
class GLVQ21(GLVQ):
    """Generalized Learning Vector Quantization 2.1.

    GLVQ trained with ``lvq21_loss``. The optimizer defaults to SGD, but an
    explicit ``optimizer`` keyword argument is honoured.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.loss = lvq21_loss
        # Default to SGD without silently discarding a caller-supplied
        # optimizer.
        self.optimizer = kwargs.get("optimizer", torch.optim.SGD)
|
|
|
|
|
|
|
|
|
|
|
|
class ImageGLVQ(PrototypeImageModel, GLVQ):
    """GLVQ variant intended for image data.

    Prototypes are kept inside the range [0, 1] by clamping them after each
    update. The clamping is presumably provided by ``PrototypeImageModel``.
    """
|
|
|
|
|
|
|
|
|
|
|
|
class ImageGMLVQ(PrototypeImageModel, GMLVQ):
    """GMLVQ variant intended for image data.

    Prototypes are kept inside the range [0, 1] by clamping them after each
    update. The clamping is presumably provided by ``PrototypeImageModel``.
    """
|