2021-05-25 18:37:34 +00:00
|
|
|
"""Models based on the GLVQ framework."""
|
|
|
|
|
2021-04-21 12:51:34 +00:00
|
|
|
import torch
|
2021-10-11 13:45:43 +00:00
|
|
|
from prototorch.core.competitions import wtac
|
|
|
|
from prototorch.core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
|
|
|
from prototorch.core.initializers import EyeTransformInitializer
|
|
|
|
from prototorch.core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
|
|
|
from prototorch.core.transforms import LinearTransform
|
|
|
|
from prototorch.nn.wrappers import LambdaLayer, LossLayer
|
2021-06-01 21:44:16 +00:00
|
|
|
from torch.nn.parameter import Parameter
|
2021-05-12 14:36:22 +00:00
|
|
|
|
2021-10-11 14:05:12 +00:00
|
|
|
from .abstract import SupervisedPrototypeModel
|
|
|
|
from .mixin import ImagePrototypesMixin
|
2021-04-29 15:05:41 +00:00
|
|
|
|
2021-04-21 12:51:34 +00:00
|
|
|
|
2021-06-04 20:20:32 +00:00
|
|
|
class GLVQ(SupervisedPrototypeModel):
    """Generalized Learning Vector Quantization."""

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Default hparams
        self.hparams.setdefault("margin", 0.0)
        self.hparams.setdefault("transfer_fn", "identity")
        self.hparams.setdefault("transfer_beta", 10.0)

        # Loss
        self.loss = GLVQLoss(
            margin=self.hparams.margin,
            transfer_fn=self.hparams.transfer_fn,
            beta=self.hparams.transfer_beta,
        )

    def initialize_prototype_win_ratios(self):
        """Reset the history of per-batch prototype win ratios.

        The buffer starts out with zero rows, shape ``(0, num_prototypes)``,
        so that every row it accumulates in `log_prototype_win_ratios` is a
        real measurement. There is no placeholder row to distort statistics
        such as the mean over steps.
        """
        self.register_buffer(
            "prototype_win_ratios",
            torch.zeros(0, self.num_prototypes, device=self.device))

    def on_epoch_start(self):
        """Start a fresh win-ratio history at the beginning of every epoch."""
        self.initialize_prototype_win_ratios()

    def log_prototype_win_ratios(self, distances):
        """Append one row with the fraction of samples won by each prototype.

        Args:
            distances: Output of `compute_distances` for one batch. The
                winner of each sample is the arg-min over the last
                dimension; presumably shape ``(batch, num_prototypes)``.
        """
        batch_size = len(distances)
        prototype_wc = torch.zeros(self.num_prototypes,
                                   dtype=torch.long,
                                   device=self.device)
        winners, counts = torch.unique(distances.min(dim=-1).indices,
                                       sorted=True,
                                       return_counts=True)
        prototype_wc[winners] = counts
        prototype_wr = prototype_wc / batch_size
        # `vstack` promotes the 1-D `prototype_wr` to a single row.
        self.prototype_win_ratios = torch.vstack([
            self.prototype_win_ratios,
            prototype_wr,
        ])

    def shared_step(self, batch, batch_idx, optimizer_idx=None):
        """Compute distances and the loss for one ``(x, y)`` batch.

        Returns:
            A tuple ``(distances, loss)``.
        """
        x, y = batch
        out = self.compute_distances(x)
        _, plabels = self.proto_layer()
        loss = self.loss(out, y, plabels)
        return out, loss

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Run one training step and log win ratios, loss and accuracy."""
        out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
        self.log_prototype_win_ratios(out)
        self.log("train_loss", train_loss)
        self.log_acc(out, batch[-1], tag="train_acc")
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step and log loss and accuracy."""
        # `model.eval()` and `torch.no_grad()` handled by pl
        out, val_loss = self.shared_step(batch, batch_idx)
        self.log("val_loss", val_loss)
        self.log_acc(out, batch[-1], tag="val_acc")
        return val_loss

    def test_step(self, batch, batch_idx):
        """Run one test step, log accuracy and return the batch loss."""
        # `model.eval()` and `torch.no_grad()` handled by pl
        out, test_loss = self.shared_step(batch, batch_idx)
        self.log_acc(out, batch[-1], tag="test_acc")
        return test_loss

    def test_epoch_end(self, outputs):
        """Log the mean of the per-batch test losses as ``test_loss``.

        Averaging (rather than summing) keeps the logged value independent
        of how many test batches there are. Nothing is logged when there
        were no test batches.

        Args:
            outputs: The loss tensors returned by `test_step`.
        """
        if not outputs:
            return
        total = sum(batch_loss.item() for batch_loss in outputs)
        self.log("test_loss", total / len(outputs))

    # TODO: implement `predict_step(self, batch, batch_idx, dataloader_idx=None)`.
|
|
|
|
|
2021-04-21 12:51:34 +00:00
|
|
|
|
2021-05-21 11:33:57 +00:00
|
|
|
class SiameseGLVQ(GLVQ):
    """GLVQ in a Siamese setting.

    GLVQ model that applies an arbitrary transformation on the inputs and the
    prototypes before computing the distances between them. The weights in the
    transformation pipeline are only learned from the inputs, unless
    ``both_path_gradients`` is ``True``.

    Args:
        hparams: Hyperparameters. ``proto_lr`` is required; ``bb_lr`` is
            required when the backbone has parameters.
        backbone: Module applied to inputs and prototypes. Defaults to a new
            ``torch.nn.Identity`` for each model.
        both_path_gradients: Whether gradients from the prototype path also
            reach the backbone parameters.

    """

    def __init__(self,
                 hparams,
                 backbone=None,
                 both_path_gradients=False,
                 **kwargs):
        distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
        # Build the default per instance instead of sharing a single module
        # object created once at function-definition time.
        self.backbone = torch.nn.Identity() if backbone is None else backbone
        self.both_path_gradients = both_path_gradients

    def configure_optimizers(self):
        """Create the prototype optimizer, plus a backbone optimizer if needed.

        Returns:
            A list of optimizers, or ``(optimizers, schedulers)`` when
            ``self.lr_scheduler`` is set (one scheduler per optimizer).
        """
        optimizers = [
            self.optimizer(self.proto_layer.parameters(),
                           lr=self.hparams.proto_lr)
        ]
        # Only add a backbone optimizer if the backbone has parameters.
        bb_params = list(self.backbone.parameters())
        if bb_params:
            optimizers.append(self.optimizer(bb_params, lr=self.hparams.bb_lr))
        if self.lr_scheduler is None:
            return optimizers
        schedulers = [
            self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)
            for optimizer in optimizers
        ]
        return optimizers, schedulers

    def compute_distances(self, x):
        """Embed inputs and prototypes with the backbone and compare them."""
        protos, _ = self.proto_layer()
        # Flatten everything but the leading dimension. `reshape` (unlike
        # `view`) also works for non-contiguous tensors.
        x, protos = (arr.reshape(arr.size(0), -1) for arr in (x, protos))
        latent_x = self.backbone(x)

        # Block gradients from the prototype path into the backbone weights
        # unless `both_path_gradients` is set. The per-parameter flags are
        # restored afterwards, so parameters frozen by the user stay frozen.
        bb_params = list(self.backbone.parameters())
        saved_flags = [param.requires_grad for param in bb_params]
        if not self.both_path_gradients:
            self.backbone.requires_grad_(False)
        try:
            latent_protos = self.backbone(protos)
        finally:
            for param, flag in zip(bb_params, saved_flags):
                param.requires_grad_(flag)

        distances = self.distance_layer(latent_x, latent_protos)
        return distances

    def predict_latent(self, x, map_protos=True):
        """Predict `x` assuming it is already embedded in the latent space.

        Only the prototypes are embedded in the latent space using the
        backbone (when `map_protos` is true). They are flattened first, as in
        `compute_distances`. The module's train/eval mode is restored before
        returning.

        Returns:
            The labels predicted by `wtac` from the latent distances.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                protos, plabels = self.proto_layer()
                if map_protos:
                    protos = self.backbone(protos.reshape(protos.size(0), -1))
                d = self.distance_layer(x, protos)
                y_pred = wtac(d, plabels)
        finally:
            self.train(was_training)
        return y_pred
|
|
|
|
|
2021-04-29 21:37:22 +00:00
|
|
|
|
2021-06-07 15:00:38 +00:00
|
|
|
class LVQMLN(SiameseGLVQ):
    """Learning Vector Quantization Multi-Layer Network.

    A GLVQ variant where only the inputs pass through the (arbitrary)
    backbone transformation and the prototypes do not. The prototypes are
    therefore located in the embedding space rather than in the input space.

    """

    def compute_distances(self, x):
        # The prototypes already live in the latent space; embed only `x`.
        prototypes, _labels = self.proto_layer()
        embedded = self.backbone(x)
        return self.distance_layer(embedded, prototypes)
|
|
|
|
|
|
|
|
|
2021-05-20 15:36:00 +00:00
|
|
|
class GRLVQ(SiameseGLVQ):
    """Generalized Relevance Learning Vector Quantization.

    Implemented as a Siamese network whose backbone scales every input
    feature by a learned relevance factor.

    The relevances are registered on the model itself, not on the backbone.
    ``SiameseGLVQ.configure_optimizers`` only collects prototype and backbone
    parameters, so it would never update them. This class therefore
    overrides ``configure_optimizers`` and trains the relevances in their own
    parameter group, using ``bb_lr`` if present and ``proto_lr`` otherwise.

    TODO Make a RelevanceLayer so the relevances become real backbone
    parameters.

    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Additional parameters
        relevances = torch.ones(self.hparams.input_dim, device=self.device)
        self.register_parameter("_relevances", Parameter(relevances))

        # Override the backbone. For finite inputs, broadcasting
        # multiplication gives the same result as
        # `x @ torch.diag(self._relevances)` without materializing an
        # (input_dim, input_dim) diagonal matrix on every call.
        self.backbone = LambdaLayer(lambda x: x * self._relevances,
                                    name="relevance scaling")

    def configure_optimizers(self):
        """Optimize prototypes and relevances with a single optimizer.

        Returns:
            ``[optimizer]``, or ``([optimizer], [scheduler])`` when
            ``self.lr_scheduler`` is set.
        """
        relevance_lr = self.hparams.get("bb_lr", self.hparams.proto_lr)
        optimizer = self.optimizer(
            [
                {"params": list(self.proto_layer.parameters())},
                {"params": [self._relevances], "lr": relevance_lr},
            ],
            lr=self.hparams.proto_lr,
        )
        optimizers = [optimizer]
        if self.lr_scheduler is None:
            return optimizers
        schedulers = [self.lr_scheduler(optimizer, **self.lr_scheduler_kwargs)]
        return optimizers, schedulers

    @property
    def relevance_profile(self):
        """The relevance vector as a detached CPU tensor."""
        return self._relevances.detach().cpu()

    def extra_repr(self):
        return f"(relevances): (shape: {tuple(self._relevances.shape)})"
|
2021-05-06 16:42:06 +00:00
|
|
|
|
|
|
|
|
2021-06-01 21:44:16 +00:00
|
|
|
class SiameseGMLVQ(SiameseGLVQ):
    """Generalized Matrix Learning Vector Quantization.

    Implemented as a Siamese network with a linear transformation backbone.
    The output size of that transformation is taken from
    ``hparams.output_dim``, or from ``hparams.latent_dim`` (the name GMLVQ
    uses) when ``output_dim`` is not given.

    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Override the backbone
        omega_initializer = kwargs.get("omega_initializer",
                                       EyeTransformInitializer())
        if "output_dim" in self.hparams:
            latent_dim = self.hparams.output_dim
        else:
            latent_dim = self.hparams.latent_dim
        self.backbone = LinearTransform(
            self.hparams.input_dim,
            latent_dim,
            initializer=omega_initializer,
        )

    @property
    def omega_matrix(self):
        """The weights of the linear backbone."""
        return self.backbone.weights

    @property
    def lambda_matrix(self):
        """``omega @ omega.T`` as a detached CPU tensor."""
        # Read the backbone through the same `weights` accessor as
        # `omega_matrix`.
        omega = self.backbone.weights  # (input_dim, latent_dim)
        lam = omega @ omega.T
        return lam.detach().cpu()
|
|
|
|
|
2021-05-17 15:00:23 +00:00
|
|
|
|
2021-06-01 21:44:16 +00:00
|
|
|
class GMLVQ(GLVQ):
    """Generalized Matrix Learning Vector Quantization.

    Implemented as a regular GLVQ network that simply uses a different distance
    function. This makes it easier to implement a localized variant.

    """

    def __init__(self, hparams, **kwargs):
        distance_fn = kwargs.pop("distance_fn", omega_distance)
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)

        # Additional parameters
        omega_initializer = kwargs.get("omega_initializer",
                                       EyeTransformInitializer())
        omega = omega_initializer.generate(self.hparams.input_dim,
                                           self.hparams.latent_dim)
        self.register_parameter("_omega", Parameter(omega))
        self.backbone = LambdaLayer(lambda x: x @ self._omega,
                                    name="omega matrix")

    @property
    def omega_matrix(self):
        """The omega parameter as a detached CPU tensor."""
        return self._omega.detach().cpu()

    @property
    def lambda_matrix(self):
        """``omega @ omega^T`` as a detached CPU tensor.

        Only the last two dimensions are transposed. A 2-D omega of shape
        ``(input_dim, latent_dim)`` yields an ``(input_dim, input_dim)``
        matrix. The stacked per-prototype omega of shape
        ``(num_prototypes, input_dim, latent_dim)`` created by LGMLVQ yields
        one such matrix per prototype. (``Tensor.T`` would reverse *all*
        dimensions of a 3-D tensor.)
        """
        omega = self._omega.detach()
        lam = omega @ omega.transpose(-2, -1)
        return lam.cpu()

    def compute_distances(self, x):
        """Distances between `x` and the prototypes under the omega metric."""
        protos, _ = self.proto_layer()
        distances = self.distance_layer(x, protos, self._omega)
        return distances

    def extra_repr(self):
        return f"(omega): (shape: {tuple(self._omega.shape)})"
|
|
|
|
|
|
|
|
|
|
|
|
class LGMLVQ(GMLVQ):
    """Localized and Generalized Matrix Learning Vector Quantization.

    Unlike GMLVQ, every prototype owns its own omega matrix.
    """

    def __init__(self, hparams, **kwargs):
        kwargs.setdefault("distance_fn", lomega_distance)
        super().__init__(hparams, **kwargs)

        # Replace the single `_omega` registered by GMLVQ with a stack of
        # randomly initialized matrices, one per prototype.
        local_shape = (
            self.num_prototypes,
            self.hparams.input_dim,
            self.hparams.latent_dim,
        )
        self.register_parameter(
            "_omega",
            Parameter(torch.randn(*local_shape, device=self.device)),
        )
|
2021-05-27 15:40:16 +00:00
|
|
|
|
|
|
|
|
2021-05-18 17:49:16 +00:00
|
|
|
class GLVQ1(GLVQ):
    """Generalized Learning Vector Quantization 1.

    Swaps the GLVQ loss for the LVQ1 loss and uses plain SGD.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.optimizer = torch.optim.SGD
        self.loss = LossLayer(lvq1_loss)
|
|
|
|
|
|
|
|
|
2021-05-18 17:49:16 +00:00
|
|
|
class GLVQ21(GLVQ):
    """Generalized Learning Vector Quantization 2.1.

    Swaps the GLVQ loss for the LVQ2.1 loss and uses plain SGD.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)
        self.optimizer = torch.optim.SGD
        self.loss = LossLayer(lvq21_loss)
|
|
|
|
|
|
|
|
|
2021-06-04 20:20:32 +00:00
|
|
|
class ImageGLVQ(ImagePrototypesMixin, GLVQ):
    """GLVQ for training on image data.

    Behaves like GLVQ, except that the prototypes are kept inside the
    interval [0, 1] by clamping them after updates.

    """
|
|
|
|
|
|
|
|
|
2021-06-04 20:20:32 +00:00
|
|
|
class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
    """GMLVQ for training on image data.

    Behaves like GMLVQ, except that the prototypes are kept inside the
    interval [0, 1] by clamping them after updates.

    """
|