prototorch_models/prototorch/models/glvq.py

135 lines
4.3 KiB
Python
Raw Normal View History

2021-04-21 12:51:34 +00:00
import pytorch_lightning as pl
import torch
import torchmetrics
2021-04-23 15:27:47 +00:00
2021-04-21 12:51:34 +00:00
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.losses import glvq_loss
from prototorch.modules.prototypes import Prototypes1D
class GLVQ(pl.LightningModule):
    """Generalized Learning Vector Quantization.

    Args:
        hparams: Hyperparameters stored via ``save_hyperparameters``. The
            keys read by this class are ``input_dim``, ``nclasses``,
            ``prototypes_per_class``, ``prototype_initializer`` and ``lr``.
            ``distance`` is optional and defaults to ``euclidean_distance``.
        **kwargs: Forwarded unchanged to ``Prototypes1D``.
    """
    def __init__(self, hparams, **kwargs):
        super().__init__()

        self.save_hyperparameters(hparams)

        # Default Values
        self.hparams.setdefault("distance", euclidean_distance)

        self.proto_layer = Prototypes1D(
            input_dim=self.hparams.input_dim,
            nclasses=self.hparams.nclasses,
            prototypes_per_class=self.hparams.prototypes_per_class,
            prototype_initializer=self.hparams.prototype_initializer,
            **kwargs)

        self.train_acc = torchmetrics.Accuracy()

    @property
    def prototypes(self):
        """Prototype vectors as a NumPy array (detached, on the host)."""
        # ``.cpu()`` is a no-op for CPU tensors; without it ``.numpy()``
        # raises when the model lives on a GPU.
        return self.proto_layer.prototypes.detach().cpu().numpy()

    @property
    def prototype_labels(self):
        """Prototype labels as a NumPy array (detached, on the host)."""
        return self.proto_layer.prototype_labels.detach().cpu().numpy()

    def configure_optimizers(self):
        """Optimize all model parameters with Adam at ``hparams.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def forward(self, x):
        """Return ``hparams.distance`` between ``x`` and all prototypes."""
        protos = self.proto_layer.prototypes
        dis = self.hparams.distance(x, protos)
        return dis

    def training_step(self, train_batch, batch_idx):
        """Compute the GLVQ loss for one ``(x, y)`` batch and track accuracy.

        Each sample is flattened to a vector before computing distances.
        The output of ``glvq_loss`` is summed over dim 0 to give the loss.

        Returns:
            The summed loss tensor used for backpropagation.
        """
        x, y = train_batch
        # Flatten every sample, e.g. images, into a single feature vector.
        x = x.view(x.size(0), -1)
        dis = self(x)
        plabels = self.proto_layer.prototype_labels
        mu = glvq_loss(dis, y, prototype_labels=plabels)
        loss = mu.sum(dim=0)
        self.log("train_loss", loss)

        with torch.no_grad():
            preds = wtac(dis, plabels)
        # Cast to int: torchmetrics treats FloatTensors as class probabilities.
        self.train_acc(preds.int(), y.int())
        # With ``on_epoch=True`` Lightning computes the epoch-level value of
        # the metric itself, so no manual ``compute()`` call is needed.
        self.log(
            "acc",
            self.train_acc,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        return loss

    def predict(self, x):
        """Predict labels for ``x`` from its distances to the prototypes.

        Runs without gradient tracking. It does NOT switch the module to
        eval mode; call ``self.eval()`` beforehand if layers such as dropout
        or batch norm must behave deterministically.

        Returns:
            The predicted labels as a NumPy array.
        """
        with torch.no_grad():
            d = self(x)
            plabels = self.proto_layer.prototype_labels
            y_pred = wtac(d, plabels)
        # Move to host memory so this also works for GPU-resident models.
        return y_pred.cpu().numpy()
class ImageGLVQ(GLVQ):
    """GLVQ for training on image data.

    GLVQ model that constrains the prototypes to the range [0, 1] by
    clamping after updates.
    """
    def on_train_batch_end(self,
                           outputs,
                           batch,
                           batch_idx,
                           dataloader_idx=0):
        """Clamp all prototype values into [0, 1] after each training batch.

        ``dataloader_idx`` is unused. It defaults to 0 so the hook can also
        be invoked without it.
        """
        # Modify the parameter in place outside of autograd, instead of
        # going through the discouraged ``.data`` attribute.
        with torch.no_grad():
            self.proto_layer.prototypes.clamp_(0.0, 1.0)
class SiameseGLVQ(GLVQ):
    """GLVQ in a Siamese setting.

    GLVQ model that applies an arbitrary transformation on the inputs and the
    prototypes before computing the distances between them. The weights in the
    transformation pipeline are only learned from the inputs.

    Args:
        hparams: Hyperparameters forwarded to ``GLVQ``. ``hparams.distance``
            (default ``euclidean_distance``) is used in the latent space.
        backbone_module: Callable, e.g. an ``nn.Module`` subclass, that builds
            the transformation. It is called twice with ``backbone_params``.
        backbone_params: Keyword arguments for ``backbone_module``. ``None``
            (the default) means no arguments.
        **kwargs: Forwarded to ``GLVQ``.
    """
    def __init__(self,
                 hparams,
                 backbone_module=torch.nn.Identity,
                 backbone_params=None,
                 **kwargs):
        super().__init__(hparams, **kwargs)
        # ``None`` instead of a ``{}`` default avoids a shared mutable default.
        if backbone_params is None:
            backbone_params = {}
        self.backbone = backbone_module(**backbone_params)
        # Copy applied to the prototypes. It has gradients disabled and only
        # receives its weights from ``self.backbone`` via ``sync_backbones``.
        self.backbone_dependent = backbone_module(
            **backbone_params).requires_grad_(False)
        # Both instances were initialized independently; align them now so
        # they agree even before the first forward pass.
        self.sync_backbones()

    def sync_backbones(self):
        """Copy the state of ``backbone`` into ``backbone_dependent``."""
        master_state = self.backbone.state_dict()
        self.backbone_dependent.load_state_dict(master_state, strict=True)

    def forward(self, x):
        """Return distances between transformed inputs and transformed
        prototypes, computed with ``hparams.distance``."""
        self.sync_backbones()
        protos = self.proto_layer.prototypes
        latent_x = self.backbone(x)
        latent_protos = self.backbone_dependent(protos)
        dis = self.hparams.distance(latent_x, latent_protos)
        return dis

    def predict_latent(self, x):
        """Predict labels for inputs that are already in latent space.

        ``x`` is NOT passed through the backbone. It is compared directly
        against the transformed prototypes, so it must already be in the
        backbone's output space. Runs without gradient tracking and does not
        change train/eval mode.

        Returns:
            The predicted labels as a NumPy array.
        """
        with torch.no_grad():
            # Ensure the prototypes are mapped with the current backbone
            # weights, not ones left over from the last forward pass.
            self.sync_backbones()
            protos = self.proto_layer.prototypes
            latent_protos = self.backbone_dependent(protos)
            d = self.hparams.distance(x, latent_protos)
            plabels = self.proto_layer.prototype_labels
            y_pred = wtac(d, plabels)
        # Move to host memory so this also works for GPU-resident models.
        return y_pred.cpu().numpy()