[REFACTOR] Major cleanup

This commit is contained in:
Jensun Ravichandran
2021-06-04 22:20:32 +02:00
parent 20471bfb1c
commit 016fcb4060
11 changed files with 481 additions and 399 deletions

View File

@@ -1,101 +1,40 @@
"""Models based on the GLVQ framework."""
import torch
import torchmetrics
from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import (
euclidean_distance,
lomega_distance,
omega_distance,
squared_euclidean_distance,
)
from prototorch.functions.helper import get_flat
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from prototorch.modules import LambdaLayer
from prototorch.modules import LambdaLayer, LossLayer
from torch.nn.parameter import Parameter
from .abstract import AbstractPrototypeModel, PrototypeImageModel
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
class GLVQ(AbstractPrototypeModel):
class GLVQ(SupervisedPrototypeModel):
"""Generalized Learning Vector Quantization."""
# Build a GLVQ model from `hparams` plus optional keyword overrides.
#
# Keyword arguments read below, with the fallback used when absent:
#   distance_fn            -> euclidean_distance
#   prototype_initializer  -> None
#   optimizer              -> torch.optim.Adam
#   lr_scheduler           -> None
#   lr_scheduler_kwargs    -> dict()
#
# NOTE(review): this span appears to interleave the pre- and post-refactor
# versions of the constructor (two `super().__init__` calls, two identical
# `transfer_fn` lookups, two `self.loss` assignments) -- confirm which lines
# belong to the current version before relying on it.
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
super().__init__()
# Hyperparameters
self.save_hyperparameters(hparams)
# Defaults
# Default hparams
# `setdefault` only fills in keys that `hparams` did not already provide.
self.hparams.setdefault("transfer_fn", "identity")
self.hparams.setdefault("transfer_beta", 10.0)
self.hparams.setdefault("lr", 0.01)
distance_fn = kwargs.get("distance_fn", euclidean_distance)
# Resolve the transfer function from its name stored in the hparams.
transfer_fn = get_activation(self.hparams.transfer_fn)
# Layers
prototype_initializer = kwargs.get("prototype_initializer", None)
self.proto_layer = LabeledComponents(
distribution=self.hparams.distribution,
initializer=prototype_initializer)
# Wrap the plain functions in LambdaLayer and keep them as attributes.
self.distance_layer = LambdaLayer(distance_fn)
transfer_fn = get_activation(self.hparams.transfer_fn)
self.transfer_layer = LambdaLayer(transfer_fn)
self.loss = LambdaLayer(glvq_loss)
# Loss
self.loss = LossLayer(glvq_loss)
# Prototype metrics
self.initialize_prototype_win_ratios()
# Optimizer class and optional scheduler settings taken from kwargs.
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
self.lr_scheduler = kwargs.get("lr_scheduler", None)
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
@property
def prototype_labels(self):
    """Component labels of the prototype layer, detached and moved to CPU."""
    labels = self.proto_layer.component_labels
    return labels.detach().cpu()
@property
def num_classes(self):
    """Number of entries in the prototype layer's ``distribution``."""
    distribution = self.proto_layer.distribution
    return len(distribution)
def _forward(self, x):
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos)
return distances
def forward(self, x):
    """Predict for ``x`` and return the predictions as identity-matrix rows.

    Distances come from ``self._forward`` and are turned into predictions by
    ``self.predict_from_distances``. Each prediction, cast to ``long``, then
    selects one row of a ``num_classes`` x ``num_classes`` identity matrix
    created on ``self.device``.
    """
    winners = self.predict_from_distances(self._forward(x))
    identity_rows = torch.eye(self.num_classes, device=self.device)
    return identity_rows[winners.long()]
def predict_from_distances(self, distances):
    """Return ``wtac`` applied to ``distances`` and the prototype labels.

    The call runs under ``torch.no_grad()``, so no gradients are recorded.
    """
    with torch.no_grad():
        return wtac(distances, self.proto_layer.component_labels)
def predict(self, x):
    """Return predictions for ``x`` with gradient tracking disabled.

    Both the distance computation (``self._forward``) and the conversion to
    predictions (``self.predict_from_distances``) run inside
    ``torch.no_grad()``.
    """
    with torch.no_grad():
        return self.predict_from_distances(self._forward(x))
def log_acc(self, distances, targets, tag):
    """Compute the accuracy for ``distances`` vs. ``targets`` and log it.

    Predictions are obtained via ``self.predict_from_distances``. The value is
    logged under ``tag`` per epoch (not per step), on the progress bar and to
    the logger.
    """
    predictions = self.predict_from_distances(distances)
    # Cast both to int: FloatTensors are assumed to be class probabilities.
    accuracy = torchmetrics.functional.accuracy(predictions.int(),
                                                targets.int())
    log_options = dict(on_step=False, on_epoch=True, prog_bar=True, logger=True)
    self.log(tag, accuracy, **log_options)
def initialize_prototype_win_ratios(self):
self.register_buffer(
"prototype_win_ratios",
@@ -121,7 +60,7 @@ class GLVQ(AbstractPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self._forward(x)
out = self.compute_distances(x)
plabels = self.proto_layer.component_labels
mu = self.loss(out, y, prototype_labels=plabels)
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
@@ -158,18 +97,6 @@ class GLVQ(AbstractPrototypeModel):
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
# pass
def add_prototypes(self, initializer, distribution):
    """Add components to the prototype layer, then set up the optimizers again.

    ``initializer`` and ``distribution`` are forwarded unchanged to
    ``self.proto_layer.add_components``. Afterwards the trainer's accelerator
    backend is asked to set up the optimizers for that trainer.
    """
    self.proto_layer.add_components(initializer, distribution)
    trainer = self.trainer
    trainer.accelerator_backend.setup_optimizers(trainer)
def remove_prototypes(self, indices):
    """Remove components from the prototype layer, then set up the optimizers again.

    ``indices`` is forwarded unchanged to
    ``self.proto_layer.remove_components``. Afterwards the trainer's
    accelerator backend is asked to set up the optimizers for that trainer.
    """
    self.proto_layer.remove_components(indices)
    trainer = self.trainer
    trainer.accelerator_backend.setup_optimizers(trainer)
def __repr__(self):
    """Return the parent class's ``__repr__`` output, formatted as a string."""
    return f"{super().__repr__()}"
class SiameseGLVQ(GLVQ):
"""GLVQ in a Siamese setting.
@@ -212,7 +139,7 @@ class SiameseGLVQ(GLVQ):
else:
return optimizer
def _forward(self, x):
def compute_distances(self, x):
protos, _ = self.proto_layer()
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
@@ -256,7 +183,7 @@ class GRLVQ(SiameseGLVQ):
def relevance_profile(self):
    """Return ``self.relevances`` detached from the graph and moved to CPU."""
    relevances = self.relevances
    return relevances.detach().cpu()
# Compute distances between `x` and the prototypes, passing a matrix derived
# from `self.relevances` to the distance layer.
#
# NOTE(review): two consecutive `def` headers -- this looks like the old
# (`_forward`) and new (`compute_distances`) header of one method shown side
# by side in a diff view; confirm which one the file should keep.
def _forward(self, x):
def compute_distances(self, x):
# The prototype layer returns a pair; its second element is not used here.
protos, _ = self.proto_layer()
# `torch.diag` is applied to `self.relevances` and the result is handed to the
# distance layer as its third argument (presumably `relevances` is a 1-D
# vector, so this builds a diagonal matrix -- TODO confirm).
distances = self.distance_layer(x, protos, torch.diag(self.relevances))
return distances
@@ -285,7 +212,7 @@ class SiameseGMLVQ(SiameseGLVQ):
lam = omega.T @ omega
return lam.detach().cpu()
def _forward(self, x):
def compute_distances(self, x):
protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
latent_x = self.backbone(x)
@@ -305,7 +232,7 @@ class LVQMLN(SiameseGLVQ):
rather in the embedding space.
"""
def _forward(self, x):
def compute_distances(self, x):
latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x)
distances = self.distance_layer(latent_x, latent_protos)
@@ -327,7 +254,7 @@ class GMLVQ(GLVQ):
device=self.device)
self.register_parameter("_omega", Parameter(omega))
# Compute distances between `x` and the prototypes, passing `self._omega` to
# the distance layer as its third argument.
#
# NOTE(review): two consecutive `def` headers -- this looks like the old
# (`_forward`) and new (`compute_distances`) header of one method shown side
# by side in a diff view; confirm which one the file should keep.
def _forward(self, x):
def compute_distances(self, x):
# The prototype layer returns a pair; its second element is not used here.
protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega)
return distances
@@ -355,7 +282,7 @@ class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs):
    """Initialize like the parent class, then use the LVQ1 loss and SGD.

    Args:
        hparams: Hyperparameters, forwarded to the parent constructor.
        **kwargs: Extra options, forwarded to the parent constructor.
    """
    super().__init__(hparams, **kwargs)
    # Assign the loss exactly once, wrapped in ``LossLayer``. The previous bare
    # ``self.loss = lvq1_loss`` was dead code (overwritten on the next line).
    # NOTE(review): if ``loss`` is already registered as a submodule by the
    # parent, assigning a plain function to it would raise TypeError in
    # torch.nn.Module -- confirm.
    self.loss = LossLayer(lvq1_loss)
    # Replace whatever optimizer class the parent constructor stored.
    self.optimizer = torch.optim.SGD
@@ -363,11 +290,11 @@ class GLVQ21(GLVQ):
"""Generalized Learning Vector Quantization 2.1."""
def __init__(self, hparams, **kwargs):
    """Initialize like the parent class, then use the LVQ2.1 loss and SGD.

    Args:
        hparams: Hyperparameters, forwarded to the parent constructor.
        **kwargs: Extra options, forwarded to the parent constructor.
    """
    super().__init__(hparams, **kwargs)
    # Assign the loss exactly once, wrapped in ``LossLayer``. The previous bare
    # ``self.loss = lvq21_loss`` was dead code (overwritten on the next line).
    # NOTE(review): if ``loss`` is already registered as a submodule by the
    # parent, assigning a plain function to it would raise TypeError in
    # torch.nn.Module -- confirm.
    self.loss = LossLayer(lvq21_loss)
    # Replace whatever optimizer class the parent constructor stored.
    self.optimizer = torch.optim.SGD
class ImageGLVQ(PrototypeImageModel, GLVQ):
class ImageGLVQ(ImagePrototypesMixin, GLVQ):
"""GLVQ for training on image data.
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
@@ -376,7 +303,7 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
"""
class ImageGMLVQ(PrototypeImageModel, GMLVQ):
class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
"""GMLVQ for training on image data.
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping