2021-04-21 19:35:52 +00:00
|
|
|
import argparse
|
|
|
|
|
2021-04-21 12:51:34 +00:00
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torch
|
2021-04-21 17:16:57 +00:00
|
|
|
import torchmetrics
|
2021-04-21 12:51:34 +00:00
|
|
|
from prototorch.functions.competitions import wtac
|
|
|
|
from prototorch.functions.distances import euclidean_distance
|
|
|
|
from prototorch.functions.initializers import get_initializer
|
|
|
|
from prototorch.functions.losses import glvq_loss
|
|
|
|
from prototorch.modules.prototypes import Prototypes1D
|
|
|
|
|
|
|
|
|
|
|
|
class GLVQ(pl.LightningModule):
    """Generalized Learning Vector Quantization."""

    def __init__(self, hparams, input_dim, nclasses, **kwargs):
        """Build the prototype layer and the training-accuracy metric.

        Args:
            hparams: Namespace-like object that must provide ``lr``,
                ``prototypes_per_class`` and ``prototype_initializer``
                (see ``add_model_specific_args``).
            input_dim: Dimensionality of the (per-sample flattened) inputs.
            nclasses: Number of classes.
            **kwargs: Forwarded unchanged to ``Prototypes1D``.
        """
        super().__init__()
        self.lr = hparams.lr
        # Assigning ``self.hparams = ...`` is rejected by newer Lightning
        # releases (the setter was removed); ``save_hyperparameters`` is the
        # supported way to store a Namespace and also writes it to checkpoints.
        self.save_hyperparameters(hparams)
        self.proto_layer = Prototypes1D(
            input_dim=input_dim,
            nclasses=nclasses,
            prototypes_per_class=hparams.prototypes_per_class,
            prototype_initializer=hparams.prototype_initializer,
            **kwargs)
        self.train_acc = torchmetrics.Accuracy()

    @property
    def prototypes(self):
        """Prototype vectors as a detached NumPy array (copied to CPU first)."""
        # ``.cpu()`` is required: ``Tensor.numpy()`` fails for CUDA tensors.
        return self.proto_layer.prototypes.detach().cpu().numpy()

    @property
    def prototype_labels(self):
        """Prototype labels as a detached NumPy array (copied to CPU first)."""
        return self.proto_layer.prototype_labels.detach().cpu().numpy()

    def configure_optimizers(self):
        """Optimize all parameters with Adam at learning rate ``self.lr``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's options.

        Adds ``--epochs``, ``--lr``, ``--batch_size``,
        ``--prototypes_per_class`` and ``--prototype_initializer``.
        """
        parser = argparse.ArgumentParser(parents=[parent_parser],
                                         add_help=False)
        parser.add_argument("--epochs", type=int, default=1)
        parser.add_argument("--lr", type=float, default=1e-2)
        parser.add_argument("--batch_size", type=int, default=32)
        parser.add_argument("--prototypes_per_class", type=int, default=1)
        parser.add_argument("--prototype_initializer",
                            type=str,
                            default="zeros")
        return parser

    def forward(self, x):
        """Return Euclidean distances between inputs ``x`` and all prototypes.

        NOTE(review): presumably shaped (batch, num_prototypes); the exact
        layout is defined by ``euclidean_distance``.
        """
        protos = self.proto_layer.prototypes
        dis = euclidean_distance(x, protos)
        return dis

    def training_step(self, train_batch, batch_idx):
        """Compute the summed GLVQ loss for one batch and log metrics.

        Logs ``train_loss`` every step and the epoch-level training accuracy
        under ``"Training Accuracy"``.
        """
        x, y = train_batch
        # Flatten each sample into a vector before computing distances.
        x = x.view(x.size(0), -1)
        dis = self(x)
        plabels = self.proto_layer.prototype_labels
        mu = glvq_loss(dis, y, prototype_labels=plabels)
        loss = mu.sum(dim=0)
        self.log("train_loss", loss)

        with torch.no_grad():
            preds = wtac(dis, plabels)
        # Cast to int: FloatTensors are assumed to be class probabilities
        # by the accuracy metric.
        self.train_acc(preds.int(), y.int())
        # on_epoch=True lets Lightning compute/reset the metric per epoch.
        self.log("Training Accuracy",
                 self.train_acc,
                 on_step=False,
                 on_epoch=True)
        return loss

    def predict(self, x):
        """Predict class labels for ``x`` without tracking gradients.

        Inputs with more than two dimensions (e.g. image batches) are
        flattened per sample, mirroring ``training_step``; 2-D inputs are
        used as-is.

        Returns:
            NumPy array of predicted labels (copied to CPU first).
        """
        if x.dim() > 2:
            x = x.reshape(x.size(0), -1)
        with torch.no_grad():
            d = self(x)
            plabels = self.proto_layer.prototype_labels
            y_pred = wtac(d, plabels)
        return y_pred.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
class ImageGLVQ(GLVQ):
    """GLVQ model that constrains the prototypes to the range [0, 1] by
    clamping after updates.
    """

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        """Clamp every prototype component into [0, 1] after each batch.

        ``dataloader_idx`` defaults to 0 so the hook keeps working with
        Lightning versions that no longer pass that argument.
        """
        # In-place update outside autograd; preferred over mutating ``.data``.
        with torch.no_grad():
            self.proto_layer.prototypes.clamp_(0., 1.)
|