prototorch_models/prototorch/models/glvq.py

273 lines
9.1 KiB
Python
Raw Normal View History

2021-04-21 12:51:34 +00:00
import torch
import torchmetrics
from prototorch.components import LabeledComponents
2021-05-04 18:56:16 +00:00
from prototorch.functions.activations import get_activation
2021-04-21 12:51:34 +00:00
from prototorch.functions.competitions import wtac
2021-05-06 16:42:06 +00:00
from prototorch.functions.distances import (euclidean_distance, omega_distance,
squared_euclidean_distance)
2021-04-21 12:51:34 +00:00
from prototorch.functions.losses import glvq_loss
from .abstract import AbstractPrototypeModel
2021-04-21 12:51:34 +00:00
class GLVQ(AbstractPrototypeModel):
    """Generalized Learning Vector Quantization.

    Hyperparameters read from ``hparams``: ``nclasses``,
    ``prototypes_per_class`` and ``prototype_initializer`` (required), plus
    the optional ``distance`` (default ``euclidean_distance``), ``optimizer``
    (default ``torch.optim.Adam``), ``transfer_function`` (default
    ``"identity"``) and ``transfer_beta`` (default ``10.0``).
    """
    def __init__(self, hparams, **kwargs):
        # ``kwargs`` is accepted so subclasses can forward extra arguments;
        # it is not used here.
        super().__init__()
        self.save_hyperparameters(hparams)

        # Default values for the optional hyperparameters.
        self.hparams.setdefault("distance", euclidean_distance)
        self.hparams.setdefault("optimizer", torch.optim.Adam)
        self.hparams.setdefault("transfer_function", "identity")
        self.hparams.setdefault("transfer_beta", 10.0)

        self.proto_layer = LabeledComponents(
            labels=(self.hparams.nclasses, self.hparams.prototypes_per_class),
            initializer=self.hparams.prototype_initializer)

        self.transfer_function = get_activation(self.hparams.transfer_function)
        self.train_acc = torchmetrics.Accuracy()

    @property
    def prototype_labels(self):
        """Detached CPU copy of the labels assigned to the prototypes."""
        return self.proto_layer.component_labels.detach().cpu()

    def forward(self, x):
        """Return distances between ``x`` and every prototype.

        The distance function is ``self.hparams.distance``.
        """
        protos, _ = self.proto_layer()
        dis = self.hparams.distance(x, protos)
        return dis

    def training_step(self, train_batch, batch_idx, optimizer_idx=None):
        """Compute the summed GLVQ loss for one batch and log metrics.

        ``train_batch`` is unpacked as ``(x, y)``. ``x`` is flattened to
        ``(batch, -1)`` before the distances are computed. ``batch_idx`` and
        ``optimizer_idx`` are part of the hook signature and are unused here.

        Returns:
            The scalar training loss: the sum over the batch of the
            transfer-function output applied to the GLVQ loss values.
        """
        x, y = train_batch
        x = x.view(x.size(0), -1)  # flatten
        dis = self(x)
        plabels = self.proto_layer.component_labels
        mu = glvq_loss(dis, y, prototype_labels=plabels)
        batch_loss = self.transfer_function(mu,
                                            beta=self.hparams.transfer_beta)
        loss = batch_loss.sum(dim=0)

        # Compute training accuracy
        with torch.no_grad():
            preds = wtac(dis, plabels)
        # `.int()` because FloatTensors are assumed to be class probabilities
        self.train_acc(preds.int(), y.int())

        # Logging
        self.log("train_loss", loss)
        self.log("acc",
                 self.train_acc,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)
        return loss

    def predict(self, x):
        """Return winner-takes-all class predictions for ``x`` as NumPy.

        Runs without gradient tracking. The result is moved to the CPU
        before conversion because ``Tensor.numpy()`` rejects tensors that
        live on another device (e.g. CUDA).
        """
        # NOTE(review): this does not switch the model to eval mode; call
        # ``model.eval()`` beforehand if eval-mode behaviour is needed.
        with torch.no_grad():
            d = self(x)
            plabels = self.proto_layer.component_labels
            y_pred = wtac(d, plabels)
        return y_pred.cpu().numpy()
class ImageGLVQ(GLVQ):
    """GLVQ variant intended for image data.

    After every training batch the prototype components are clamped, in
    place, to the closed interval [0, 1].
    """
    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # Clamp through ``.data`` so the in-place update is not tracked by
        # autograd.
        components = self.proto_layer.components
        components.data.clamp_(min=0.0, max=1.0)
2021-04-27 12:35:17 +00:00
class SiameseGLVQ(GLVQ):
    """GLVQ in a Siamese setting.

    GLVQ model that applies an arbitrary transformation on the inputs and the
    prototypes before computing the distances between them. The weights in
    the transformation pipeline are only learned from the inputs.

    Two instances of the backbone are built:

    * ``backbone`` is trainable and is applied to the inputs.
    * ``backbone_dependent`` is frozen (``requires_grad_(False)``) and is
      applied to the prototypes.

    When ``sync`` is true, ``backbone_dependent`` receives a copy of
    ``backbone``'s weights at the start of every ``forward`` call.
    """
    def __init__(self,
                 hparams,
                 backbone_module=torch.nn.Identity,
                 backbone_params=None,
                 sync=True,
                 **kwargs):
        super().__init__(hparams, **kwargs)
        # ``None`` instead of a shared mutable ``{}`` default.
        if backbone_params is None:
            backbone_params = {}
        self.backbone = backbone_module(**backbone_params)
        self.backbone_dependent = backbone_module(
            **backbone_params).requires_grad_(False)
        self.sync = sync

    def sync_backbones(self):
        """Copy the trainable backbone's weights into the frozen copy."""
        master_state = self.backbone.state_dict()
        self.backbone_dependent.load_state_dict(master_state, strict=True)

    def configure_optimizers(self):
        """Build one optimizer for the prototypes and one for the backbone.

        Both use ``hparams.optimizer``. The prototypes use ``hparams.proto_lr``
        and the backbone uses ``hparams.bb_lr``.

        Returns:
            ``(proto_opt, bb_opt)`` if the backbone has parameters, otherwise
            ``proto_opt`` alone.
        """
        optim = self.hparams.optimizer
        proto_opt = optim(self.proto_layer.parameters(),
                          lr=self.hparams.proto_lr)
        if list(self.backbone.parameters()):
            # Only add a backbone optimizer if the backbone has any
            # parameters; optimizers fail on an empty parameter list.
            bb_opt = optim(self.backbone.parameters(), lr=self.hparams.bb_lr)
            return proto_opt, bb_opt
        return proto_opt

    def forward(self, x):
        """Return Euclidean distances between embedded inputs and prototypes.

        Inputs go through ``backbone`` and prototypes through
        ``backbone_dependent``, after an optional weight sync.
        """
        if self.sync:
            self.sync_backbones()
        protos, _ = self.proto_layer()
        latent_x = self.backbone(x)
        latent_protos = self.backbone_dependent(protos)
        dis = euclidean_distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone. Returns a NumPy array of predicted labels. The tensor is
        moved to the CPU first because ``Tensor.numpy()`` rejects non-CPU
        tensors.
        """
        # NOTE(review): this does not switch the model to eval mode.
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            latent_protos = self.backbone_dependent(protos)
            d = euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.cpu().numpy()
2021-04-29 21:37:22 +00:00
2021-05-06 16:42:06 +00:00
class GRLVQ(GLVQ):
    """Generalized Relevance Learning Vector Quantization.

    Learns a relevance vector of length ``hparams.input_dim``, initialised
    to ones. Its diagonal matrix scales each input dimension before the
    distance is computed.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.relevances = torch.nn.parameter.Parameter(
            torch.ones(self.hparams.input_dim))

    def forward(self, x):
        """Return relevance-weighted distances between ``x`` and prototypes."""
        protos, _ = self.proto_layer()
        # NOTE(review): presumably ``omega_distance`` is the squared
        # Euclidean distance after projecting both arguments by the given
        # matrix, consistent with ``predict_latent`` below -- confirm.
        dis = omega_distance(x, protos, torch.diag(self.relevances))
        return dis

    def backbone(self, x):
        """Project ``x`` with the diagonal relevance matrix."""
        return x @ torch.diag(self.relevances)

    @property
    def relevance_profile(self):
        """Detached CPU copy of the learned relevance vector."""
        return self.relevances.detach().cpu()

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space, by multiplying
        them with the diagonal relevance matrix. Returns a NumPy array of
        predicted labels, moved to the CPU first because ``Tensor.numpy()``
        rejects non-CPU tensors.
        """
        # NOTE(review): this does not switch the model to eval mode.
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            latent_protos = protos @ torch.diag(self.relevances)
            d = squared_euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.cpu().numpy()
2021-04-29 21:37:22 +00:00
class GMLVQ(GLVQ):
    """Generalized Matrix Learning Vector Quantization.

    A bias-free linear layer (``omega_layer``, ``input_dim -> latent_dim``)
    maps both inputs and prototypes into a latent space. Squared Euclidean
    distances are computed in that latent space.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.omega_layer = torch.nn.Linear(self.hparams.input_dim,
                                           self.hparams.latent_dim,
                                           bias=False)
        # Namespace hook for the visualization callbacks to work
        self.backbone = self.omega_layer

    @property
    def omega_matrix(self):
        """Detached CPU copy of the projection weight.

        Shape is ``(latent_dim, input_dim)``.
        """
        return self.omega_layer.weight.detach().cpu()

    @property
    def lambda_matrix(self):
        """Detached CPU copy of ``omega.T @ omega``.

        The result has shape ``(input_dim, input_dim)``.
        """
        # Detach first so no autograd graph is built for a read-only view.
        omega = self.omega_layer.weight.detach()  # (latent_dim, input_dim)
        lam = omega.T @ omega
        return lam.cpu()

    def show_lambda(self):
        """Display ``lambda_matrix`` in a blocking matplotlib window."""
        import matplotlib.pyplot as plt
        title = "Lambda matrix"
        plt.figure(title)
        plt.title(title)
        plt.imshow(self.lambda_matrix, cmap="gray")
        plt.axis("off")
        plt.colorbar()
        plt.show(block=True)

    def forward(self, x):
        """Return squared Euclidean distances in the ``omega_layer`` space."""
        protos, _ = self.proto_layer()
        latent_x = self.omega_layer(x)
        latent_protos = self.omega_layer(protos)
        dis = squared_euclidean_distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone (``omega_layer``). Returns a NumPy array of predicted
        labels, moved to the CPU first because ``Tensor.numpy()`` rejects
        non-CPU tensors.
        """
        # NOTE(review): this does not switch the model to eval mode.
        with torch.no_grad():
            protos, plabels = self.proto_layer()
            latent_protos = self.omega_layer(protos)
            d = squared_euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.cpu().numpy()
2021-04-29 21:37:22 +00:00
class LVQMLN(GLVQ):
    """Learning Vector Quantization Multi-Layer Network.

    GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
    on the prototypes, before computing the distances between them. This, of
    course, means that the prototypes no longer live in the input space, but
    rather in the embedding space.
    """
    def __init__(self,
                 hparams,
                 backbone_module=torch.nn.Identity,
                 backbone_params=None,
                 **kwargs):
        super().__init__(hparams, **kwargs)
        # ``None`` instead of a shared mutable ``{}`` default.
        if backbone_params is None:
            backbone_params = {}
        self.backbone = backbone_module(**backbone_params)
        # Move the initial (input-space) prototypes into the embedding space
        # once, so that training starts from embedded prototypes.
        with torch.no_grad():
            protos = self.backbone(self.proto_layer()[0])
        # NOTE(review): ``strict=False`` only tolerates missing/unexpected
        # keys. ``load_state_dict`` still raises on a shape mismatch, so this
        # assumes the backbone preserves the prototype dimensionality.
        self.proto_layer.load_state_dict({"_components": protos}, strict=False)

    def forward(self, x):
        """Return Euclidean distances between embedded ``x`` and prototypes.

        The prototypes are already in the embedding space, so only ``x`` is
        passed through the backbone.
        """
        latent_protos, _ = self.proto_layer()
        latent_x = self.backbone(x)
        dis = euclidean_distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict `x` assuming it is already embedded in the latent space.

        Returns a NumPy array of predicted labels, moved to the CPU first
        because ``Tensor.numpy()`` rejects non-CPU tensors.
        """
        with torch.no_grad():
            latent_protos, plabels = self.proto_layer()
            d = euclidean_distance(x, latent_protos)
            y_pred = wtac(d, plabels)
        return y_pred.cpu().numpy()