# prototorch/models/glvq.py -- Generalized Learning Vector Quantization
# (GLVQ) model family built on PyTorch Lightning.
import torch
import torchmetrics

from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance,
                                            sed)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import (_get_dp_dm, _get_matcher, glvq_loss,
                                         lvq1_loss, lvq21_loss)

from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
                       SiamesePrototypeModel)
class GLVQ(AbstractPrototypeModel):
    """Generalized Learning Vector Quantization.

    Learns a set of labeled prototypes by minimizing the GLVQ loss (the
    scaled difference of the distances to the closest correct and closest
    incorrect prototype), optionally squashed by a transfer function.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__()

        self.save_hyperparameters(hparams)

        # Pluggable, non-hyperparameter components.
        self.distance_fn = kwargs.get("distance_fn", euclidean_distance)
        self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
        prototype_initializer = kwargs.get("prototype_initializer", None)

        # Default hyperparameter values.
        self.hparams.setdefault("transfer_fn", "identity")
        self.hparams.setdefault("transfer_beta", 10.0)
        self.hparams.setdefault("lr", 0.01)

        self.proto_layer = LabeledComponents(
            distribution=self.hparams.distribution,
            initializer=prototype_initializer)

        self.transfer_fn = get_activation(self.hparams.transfer_fn)
        self.acc_metric = torchmetrics.Accuracy()

        self.loss = glvq_loss

    @property
    def prototype_labels(self):
        """Class labels of the prototypes as a detached CPU tensor."""
        return self.proto_layer.component_labels.detach().cpu()

    @property
    def num_classes(self):
        """Number of classes, inferred from the prototype distribution."""
        return len(self.proto_layer.distribution)

    def _forward(self, x):
        """Return the distances of `x` to all prototypes."""
        protos, _ = self.proto_layer()
        distances = self.distance_fn(x, protos)
        return distances

    def forward(self, x):
        """Return one-hot (pseudo-probability) class predictions for `x`."""
        distances = self._forward(x)
        y_pred = self.predict_from_distances(distances)
        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
        return y_pred

    def predict_from_distances(self, distances):
        """Winner-takes-all class prediction from precomputed distances."""
        with torch.no_grad():
            plabels = self.proto_layer.component_labels
            y_pred = wtac(distances, plabels)
        return y_pred

    def predict(self, x):
        """Predict class labels for `x` without tracking gradients."""
        with torch.no_grad():
            distances = self._forward(x)
            y_pred = self.predict_from_distances(distances)
        return y_pred

    def log_acc(self, distances, targets, tag):
        """Update and log the accuracy metric under the name `tag`."""
        preds = self.predict_from_distances(distances)
        # `.int()` because FloatTensors are assumed to be class probabilities
        self.acc_metric(preds.int(), targets.int())
        self.log(tag,
                 self.acc_metric,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)

    def shared_step(self, batch, batch_idx, optimizer_idx=None):
        """Common distance/loss computation for train/val/test steps.

        Returns the raw distances (for accuracy logging) and the batch loss.
        """
        x, y = batch
        out = self._forward(x)
        plabels = self.proto_layer.component_labels
        mu = self.loss(out, y, prototype_labels=plabels)
        batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
        loss = batch_loss.sum(dim=0)
        return out, loss

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
        self.log("train_loss", train_loss)
        self.log_acc(out, batch[-1], tag="train_acc")
        return train_loss

    def validation_step(self, batch, batch_idx):
        # `model.eval()` and `torch.no_grad()` handled by pl
        out, val_loss = self.shared_step(batch, batch_idx)
        self.log("val_loss", val_loss)
        self.log_acc(out, batch[-1], tag="val_acc")
        return val_loss

    def test_step(self, batch, batch_idx):
        # `model.eval()` and `torch.no_grad()` handled by pl
        out, test_loss = self.shared_step(batch, batch_idx)
        self.log_acc(out, batch[-1], tag="test_acc")
        return test_loss

    def test_epoch_end(self, outputs):
        """Accumulate and log the total test loss over all batches."""
        test_loss = 0.0
        for batch_loss in outputs:
            test_loss += batch_loss.item()
        self.log("test_loss", test_loss)

    # def predict_step(self, batch, batch_idx, dataloader_idx=None):
    #     pass

    def __repr__(self):
        super_repr = super().__repr__()
        return f"{super_repr}"
class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
    """GLVQ in a Siamese setting.

    GLVQ model that applies an arbitrary transformation on the inputs and
    the prototypes before computing the distances between them. The weights
    in the transformation pipeline are only learned from the inputs.
    """
    def __init__(self,
                 hparams,
                 backbone=None,
                 both_path_gradients=False,
                 **kwargs):
        super().__init__(hparams, **kwargs)
        # Avoid a mutable default argument: a module instance used as a
        # default would be shared across all SiameseGLVQ instances.
        self.backbone = backbone if backbone is not None else torch.nn.Identity()
        self.both_path_gradients = both_path_gradients
        self.distance_fn = kwargs.get("distance_fn", sed)

    def _forward(self, x):
        protos, _ = self.proto_layer()
        latent_x = self.backbone(x)
        # Block gradients through the prototype path unless both-path
        # gradients were explicitly requested.
        self.backbone.requires_grad_(self.both_path_gradients)
        latent_protos = self.backbone(protos)
        self.backbone.requires_grad_(True)
        distances = self.distance_fn(latent_x, latent_protos)
        return distances
class GRLVQ(SiameseGLVQ):
    """Generalized Relevance Learning Vector Quantization.

    Learns a diagonal relevance weighting of the input dimensions in
    addition to the prototypes.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        # One learnable relevance per input dimension, initialized to 1.
        self.relevances = torch.nn.parameter.Parameter(
            torch.ones(self.hparams.input_dim))

    @property
    def relevance_profile(self):
        """Learned relevances as a detached CPU tensor."""
        return self.relevances.detach().cpu()

    def backbone(self, x):
        """Namespace hook for the visualization callbacks to work."""
        return x @ torch.diag(self.relevances)

    def _forward(self, x):
        protos, _ = self.proto_layer()
        distances = omega_distance(x, protos, torch.diag(self.relevances))
        return distances
class GMLVQ(SiameseGLVQ):
    """Generalized Matrix Learning Vector Quantization.

    Learns a linear projection (omega matrix) of the inputs and prototypes
    into a latent space where the distances are computed.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                        self.hparams.latent_dim,
                                        bias=False)

    @property
    def omega_matrix(self):
        """Projection matrix omega as a detached CPU tensor."""
        return self.backbone.weight.detach().cpu()

    @property
    def lambda_matrix(self):
        """Classification correlation matrix lambda = omega^T omega."""
        omega = self.backbone.weight  # (latent_dim, input_dim)
        lam = omega.T @ omega
        return lam.detach().cpu()

    def show_lambda(self):
        """Visualize the lambda matrix with matplotlib (blocking)."""
        import matplotlib.pyplot as plt
        title = "Lambda matrix"
        plt.figure(title)
        plt.title(title)
        plt.imshow(self.lambda_matrix, cmap="gray")
        plt.axis("off")
        plt.colorbar()
        plt.show(block=True)

    def _forward(self, x):
        protos, _ = self.proto_layer()
        x, protos = get_flat(x, protos)
        latent_x = self.backbone(x)
        # Block gradients through the prototype path unless requested.
        self.backbone.requires_grad_(self.both_path_gradients)
        latent_protos = self.backbone(protos)
        self.backbone.requires_grad_(True)
        distances = self.distance_fn(latent_x, latent_protos)
        return distances
class LVQMLN(SiameseGLVQ):
    """Learning Vector Quantization Multi-Layer Network.

    GLVQ model that applies an arbitrary transformation on the inputs, BUT
    NOT on the prototypes before computing the distances between them. This,
    of course, means that the prototypes no longer live in the input space,
    but rather in the embedding space.
    """
    def _forward(self, x):
        # Prototypes already live in the latent space; only `x` is embedded.
        latent_protos, _ = self.proto_layer()
        latent_x = self.backbone(x)
        distances = self.distance_fn(latent_x, latent_protos)
        return distances
class NonGradientGLVQ(GLVQ):
    """Abstract base for GLVQ variants that do not use gradient descent.

    Subclasses override `training_step` to update the prototypes manually.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable Lightning's automatic optimization; updates happen by hand.
        self.automatic_optimization = False

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        raise NotImplementedError
class LVQ1(NonGradientGLVQ):
    """Learning Vector Quantization 1 with manual prototype updates."""
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        protos = self.proto_layer.components
        plabels = self.proto_layer.component_labels
        x, y = train_batch
        dis = self._forward(x)
        # TODO Vectorized implementation
        for xi, yi in zip(x, y):
            # Use raw prototype distances. `self(...)` would return one-hot
            # class predictions, so `wtac`/`argmin`/`protos[w]` below would
            # operate on the wrong tensor.
            d = self._forward(xi.view(1, -1))
            preds = wtac(d, plabels)
            w = d.argmin(1)
            # Attract the winner toward correctly classified samples and
            # repel it from misclassified ones.
            if yi == preds:
                shift = xi - protos[w]
            else:
                shift = protos[w] - xi
            updated_protos = protos + 0.0
            updated_protos[w] = protos[w] + (self.hparams.lr * shift)
            self.proto_layer.load_state_dict({"_components": updated_protos},
                                             strict=False)
        # Logging
        self.log_acc(dis, y, tag="train_acc")
        return None
class LVQ21(NonGradientGLVQ):
    """Learning Vector Quantization 2.1 with manual prototype updates."""
    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        protos = self.proto_layer.components
        plabels = self.proto_layer.component_labels
        x, y = train_batch
        dis = self._forward(x)
        # TODO Vectorized implementation
        for xi, yi in zip(x, y):
            xi = xi.view(1, -1)
            yi = yi.view(1, )
            # Use raw prototype distances. `self(...)` would return one-hot
            # class predictions, which `_get_dp_dm` cannot interpret.
            d = self._forward(xi)
            # Closest correct (dp, wp) and closest incorrect (dn, wn).
            (dp, wp), (dn, wn) = _get_dp_dm(d, yi, plabels, with_indices=True)
            shiftp = xi - protos[wp]
            shiftn = protos[wn] - xi
            updated_protos = protos + 0.0
            updated_protos[wp] = protos[wp] + (self.hparams.lr * shiftp)
            updated_protos[wn] = protos[wn] + (self.hparams.lr * shiftn)
            self.proto_layer.load_state_dict({"_components": updated_protos},
                                             strict=False)
        # Logging
        self.log_acc(dis, y, tag="train_acc")
        return None
class MedianLVQ(NonGradientGLVQ):
    """Median LVQ (placeholder; not yet implemented)."""
class GLVQ1(GLVQ):
    """Generalized Learning Vector Quantization 1.

    GLVQ trained with the LVQ1 loss and plain SGD.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.loss = lvq1_loss
        self.optimizer = torch.optim.SGD
class GLVQ21(GLVQ):
    """Generalized Learning Vector Quantization 2.1.

    GLVQ trained with the LVQ2.1 loss and plain SGD.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.loss = lvq21_loss
        self.optimizer = torch.optim.SGD
class ImageGLVQ(PrototypeImageModel, GLVQ):
    """GLVQ for training on image data.

    GLVQ model that constrains the prototypes to the range [0, 1] by
    clamping after updates.
    """
    pass
class ImageGMLVQ(PrototypeImageModel, GMLVQ):
    """GMLVQ for training on image data.

    GMLVQ model that constrains the prototypes to the range [0, 1] by
    clamping after updates.
    """
    pass