[REFACTOR] Major cleanup
This commit is contained in:
@@ -1,101 +1,40 @@
|
||||
"""Models based on the GLVQ framework."""
|
||||
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.components import LabeledComponents
|
||||
from prototorch.functions.activations import get_activation
|
||||
from prototorch.functions.competitions import wtac
|
||||
from prototorch.functions.distances import (
|
||||
euclidean_distance,
|
||||
lomega_distance,
|
||||
omega_distance,
|
||||
squared_euclidean_distance,
|
||||
)
|
||||
from prototorch.functions.helper import get_flat
|
||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from prototorch.modules import LambdaLayer
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
|
||||
|
||||
class GLVQ(AbstractPrototypeModel):
|
||||
class GLVQ(SupervisedPrototypeModel):
|
||||
"""Generalized Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Hyperparameters
|
||||
self.save_hyperparameters(hparams)
|
||||
|
||||
# Defaults
|
||||
# Default hparams
|
||||
self.hparams.setdefault("transfer_fn", "identity")
|
||||
self.hparams.setdefault("transfer_beta", 10.0)
|
||||
self.hparams.setdefault("lr", 0.01)
|
||||
|
||||
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||
transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||
|
||||
# Layers
|
||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||
self.proto_layer = LabeledComponents(
|
||||
distribution=self.hparams.distribution,
|
||||
initializer=prototype_initializer)
|
||||
|
||||
self.distance_layer = LambdaLayer(distance_fn)
|
||||
transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||
self.transfer_layer = LambdaLayer(transfer_fn)
|
||||
self.loss = LambdaLayer(glvq_loss)
|
||||
|
||||
# Loss
|
||||
self.loss = LossLayer(glvq_loss)
|
||||
|
||||
# Prototype metrics
|
||||
self.initialize_prototype_win_ratios()
|
||||
|
||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||
self.lr_scheduler = kwargs.get("lr_scheduler", None)
|
||||
self.lr_scheduler_kwargs = kwargs.get("lr_scheduler_kwargs", dict())
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
return self.proto_layer.component_labels.detach().cpu()
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
return len(self.proto_layer.distribution)
|
||||
|
||||
def _forward(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos)
|
||||
return distances
|
||||
|
||||
def forward(self, x):
|
||||
distances = self._forward(x)
|
||||
y_pred = self.predict_from_distances(distances)
|
||||
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.long()]
|
||||
return y_pred
|
||||
|
||||
def predict_from_distances(self, distances):
|
||||
with torch.no_grad():
|
||||
plabels = self.proto_layer.component_labels
|
||||
y_pred = wtac(distances, plabels)
|
||||
return y_pred
|
||||
|
||||
def predict(self, x):
|
||||
with torch.no_grad():
|
||||
distances = self._forward(x)
|
||||
y_pred = self.predict_from_distances(distances)
|
||||
return y_pred
|
||||
|
||||
def log_acc(self, distances, targets, tag):
|
||||
preds = self.predict_from_distances(distances)
|
||||
accuracy = torchmetrics.functional.accuracy(preds.int(), targets.int())
|
||||
# `.int()` because FloatTensors are assumed to be class probabilities
|
||||
|
||||
self.log(tag,
|
||||
accuracy,
|
||||
on_step=False,
|
||||
on_epoch=True,
|
||||
prog_bar=True,
|
||||
logger=True)
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.register_buffer(
|
||||
"prototype_win_ratios",
|
||||
@@ -121,7 +60,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
|
||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self._forward(x)
|
||||
out = self.compute_distances(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
mu = self.loss(out, y, prototype_labels=plabels)
|
||||
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
|
||||
@@ -158,18 +97,6 @@ class GLVQ(AbstractPrototypeModel):
|
||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
# pass
|
||||
|
||||
def add_prototypes(self, initializer, distribution):
|
||||
self.proto_layer.add_components(initializer, distribution)
|
||||
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
|
||||
|
||||
def remove_prototypes(self, indices):
|
||||
self.proto_layer.remove_components(indices)
|
||||
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
|
||||
|
||||
def __repr__(self):
|
||||
super_repr = super().__repr__()
|
||||
return f"{super_repr}"
|
||||
|
||||
|
||||
class SiameseGLVQ(GLVQ):
|
||||
"""GLVQ in a Siamese setting.
|
||||
@@ -212,7 +139,7 @@ class SiameseGLVQ(GLVQ):
|
||||
else:
|
||||
return optimizer
|
||||
|
||||
def _forward(self, x):
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
latent_x = self.backbone(x)
|
||||
self.backbone.requires_grad_(self.both_path_gradients)
|
||||
@@ -256,7 +183,7 @@ class GRLVQ(SiameseGLVQ):
|
||||
def relevance_profile(self):
|
||||
return self.relevances.detach().cpu()
|
||||
|
||||
def _forward(self, x):
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos, torch.diag(self.relevances))
|
||||
return distances
|
||||
@@ -285,7 +212,7 @@ class SiameseGMLVQ(SiameseGLVQ):
|
||||
lam = omega.T @ omega
|
||||
return lam.detach().cpu()
|
||||
|
||||
def _forward(self, x):
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
x, protos = get_flat(x, protos)
|
||||
latent_x = self.backbone(x)
|
||||
@@ -305,7 +232,7 @@ class LVQMLN(SiameseGLVQ):
|
||||
rather in the embedding space.
|
||||
|
||||
"""
|
||||
def _forward(self, x):
|
||||
def compute_distances(self, x):
|
||||
latent_protos, _ = self.proto_layer()
|
||||
latent_x = self.backbone(x)
|
||||
distances = self.distance_layer(latent_x, latent_protos)
|
||||
@@ -327,7 +254,7 @@ class GMLVQ(GLVQ):
|
||||
device=self.device)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
def _forward(self, x):
|
||||
def compute_distances(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
distances = self.distance_layer(x, protos, self._omega)
|
||||
return distances
|
||||
@@ -355,7 +282,7 @@ class GLVQ1(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 1."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = lvq1_loss
|
||||
self.loss = LossLayer(lvq1_loss)
|
||||
self.optimizer = torch.optim.SGD
|
||||
|
||||
|
||||
@@ -363,11 +290,11 @@ class GLVQ21(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 2.1."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
self.loss = lvq21_loss
|
||||
self.loss = LossLayer(lvq21_loss)
|
||||
self.optimizer = torch.optim.SGD
|
||||
|
||||
|
||||
class ImageGLVQ(PrototypeImageModel, GLVQ):
|
||||
class ImageGLVQ(ImagePrototypesMixin, GLVQ):
|
||||
"""GLVQ for training on image data.
|
||||
|
||||
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||
@@ -376,7 +303,7 @@ class ImageGLVQ(PrototypeImageModel, GLVQ):
|
||||
"""
|
||||
|
||||
|
||||
class ImageGMLVQ(PrototypeImageModel, GMLVQ):
|
||||
class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
|
||||
"""GMLVQ for training on image data.
|
||||
|
||||
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||
|
Reference in New Issue
Block a user