diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py
index d7f57c5..c42a7b5 100644
--- a/prototorch/models/abstract.py
+++ b/prototorch/models/abstract.py
@@ -1,5 +1,6 @@
 import pytorch_lightning as pl
 import torch
+from prototorch.functions.competitions import wtac
 from torch.optim.lr_scheduler import ExponentialLR
 
 
@@ -29,3 +30,33 @@ class AbstractPrototypeModel(pl.LightningModule):
 class PrototypeImageModel(pl.LightningModule):
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
         self.proto_layer.components.data.clamp_(0.0, 1.0)
+
+
+class SiamesePrototypeModel(pl.LightningModule):
+    def configure_optimizers(self):
+        proto_opt = self.optimizer(self.proto_layer.parameters(),
+                                   lr=self.hparams.proto_lr)
+        if list(self.backbone.parameters()):
+            # only add an optimizer if the backbone has trainable parameters
+            # otherwise, the next line fails
+            bb_opt = self.optimizer(self.backbone.parameters(),
+                                    lr=self.hparams.bb_lr)
+            return proto_opt, bb_opt
+        else:
+            return proto_opt
+
+    def predict_latent(self, x, map_protos=True):
+        """Predict `x` assuming it is already embedded in the latent space.
+
+        Only the prototypes are embedded in the latent space using the
+        backbone.
+
+        """
+        # model.eval() # ?!
+        with torch.no_grad():
+            protos, plabels = self.proto_layer()
+            if map_protos:
+                protos = self.backbone(protos)
+            d = self.distance_fn(x, protos)
+            y_pred = wtac(d, plabels)
+            return y_pred
diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py
index 48c195a..e725a17 100644
--- a/prototorch/models/glvq.py
+++ b/prototorch/models/glvq.py
@@ -4,11 +4,12 @@ from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import (euclidean_distance, omega_distance,
-                                            squared_euclidean_distance)
+                                            sed)
 from prototorch.functions.helper import get_flat
 from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 
-from .abstract import AbstractPrototypeModel, PrototypeImageModel
+from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
+                       SiamesePrototypeModel)
 
 
 class GLVQ(AbstractPrototypeModel):
@@ -25,11 +26,11 @@ class GLVQ(AbstractPrototypeModel):
 
         self.save_hyperparameters(hparams)
 
+        self.distance_fn = kwargs.get("distance_fn", euclidean_distance)
         self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
         prototype_initializer = kwargs.get("prototype_initializer", None)
 
         # Default Values
-        self.hparams.setdefault("distance", euclidean_distance)
         self.hparams.setdefault("transfer_function", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
 
@@ -48,7 +49,7 @@ class GLVQ(AbstractPrototypeModel):
 
     def forward(self, x):
         protos, _ = self.proto_layer()
-        dis = self.hparams.distance(x, protos)
+        dis = self.distance_fn(x, protos)
         return dis
 
     def training_step(self, train_batch, batch_idx, optimizer_idx=None):
@@ -87,33 +88,7 @@ class GLVQ(AbstractPrototypeModel):
         return y_pred
 
 
-class LVQ1(GLVQ):
-    """Learning Vector Quantization 1."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.loss = lvq1_loss
-        self.optimizer = torch.optim.SGD
-
-
-class LVQ21(GLVQ):
-    """Learning Vector Quantization 2.1."""
-    def __init__(self, hparams, **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.loss = lvq21_loss
-        self.optimizer = torch.optim.SGD
-
-
-class ImageGLVQ(GLVQ, PrototypeImageModel):
-    """GLVQ for training on image data.
-
-    GLVQ model that constrains the prototypes to the range [0, 1] by clamping
-    after updates.
-
-    """
-    pass
-
-
-class SiameseGLVQ(GLVQ):
+class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
     """GLVQ in a Siamese setting.
 
     GLVQ model that applies an arbitrary transformation on the inputs and the
@@ -123,110 +98,62 @@ class SiameseGLVQ(GLVQ):
     """
     def __init__(self,
                  hparams,
-                 backbone_module=torch.nn.Identity,
-                 backbone_params={},
-                 sync=True,
+                 backbone=torch.nn.Identity(),
+                 both_path_gradients=False,
                  **kwargs):
         super().__init__(hparams, **kwargs)
-        self.backbone = backbone_module(**backbone_params)
-        self.backbone_dependent = backbone_module(
-            **backbone_params).requires_grad_(False)
-        self.sync = sync
-
-    def sync_backbones(self):
-        master_state = self.backbone.state_dict()
-        self.backbone_dependent.load_state_dict(master_state, strict=True)
-
-    def configure_optimizers(self):
-        proto_opt = self.optimizer(self.proto_layer.parameters(),
-                                   lr=self.hparams.proto_lr)
-        if list(self.backbone.parameters()):
-            # only add an optimizer is the backbone has trainable parameters
-            # otherwise, the next line fails
-            bb_opt = self.optimizer(self.backbone.parameters(),
-                                    lr=self.hparams.bb_lr)
-            return proto_opt, bb_opt
-        else:
-            return proto_opt
+        self.backbone = backbone
+        self.both_path_gradients = both_path_gradients
+        self.distance_fn = kwargs.get("distance_fn", sed)
 
     def forward(self, x):
-        if self.sync:
-            self.sync_backbones()
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
-        latent_protos = self.backbone_dependent(protos)
-        dis = euclidean_distance(latent_x, latent_protos)
+        self.backbone.requires_grad_(self.both_path_gradients)
+        latent_protos = self.backbone(protos)
+        self.backbone.requires_grad_(True)
+        dis = self.distance_fn(latent_x, latent_protos)
         return dis
 
-    def predict_latent(self, x):
-        """Predict `x` assuming it is already embedded in the latent space.
-
-        Only the prototypes are embedded in the latent space using the
-        backbone.
-
-        """
-        # model.eval() # ?!
-        with torch.no_grad():
-            protos, plabels = self.proto_layer()
-            latent_protos = self.backbone_dependent(protos)
-            d = euclidean_distance(x, latent_protos)
-            y_pred = wtac(d, plabels)
-            return y_pred
-
-
-class GRLVQ(GLVQ):
+
+class GRLVQ(SiamesePrototypeModel, GLVQ):
     """Generalized Relevance Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.relevances = torch.nn.parameter.Parameter(
             torch.ones(self.hparams.input_dim))
+        self.distance_fn = kwargs.get("distance_fn", sed)
+
+    @property
+    def relevance_profile(self):
+        return self.relevances.detach().cpu()
+
+    def backbone(self, x):
+        """Namespace hook for the visualization callbacks to work."""
+        return x @ torch.diag(self.relevances)
 
     def forward(self, x):
         protos, _ = self.proto_layer()
         dis = omega_distance(x, protos, torch.diag(self.relevances))
         return dis
 
-    def backbone(self, x):
-        return x @ torch.diag(self.relevances)
-
-    @property
-    def relevance_profile(self):
-        return self.relevances.detach().cpu()
-
-    def predict_latent(self, x):
-        """Predict `x` assuming it is already embedded in the latent space.
-
-        Only the prototypes are embedded in the latent space using the
-        backbone.
-
-        """
-        # model.eval() # ?!
-        with torch.no_grad():
-            protos, plabels = self.proto_layer()
-            latent_protos = protos @ torch.diag(self.relevances)
-            d = squared_euclidean_distance(x, latent_protos)
-            y_pred = wtac(d, plabels)
-            return y_pred
-
-
-class GMLVQ(GLVQ):
+
+class GMLVQ(SiamesePrototypeModel, GLVQ):
     """Generalized Matrix Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
-        self.omega_layer = torch.nn.Linear(self.hparams.input_dim,
-                                           self.hparams.latent_dim,
-                                           bias=False)
-
-        # Namespace hook for the visualization callbacks to work
-        self.backbone = self.omega_layer
+        self.backbone = torch.nn.Linear(self.hparams.input_dim,
+                                        self.hparams.latent_dim,
+                                        bias=False)
+        self.distance_fn = kwargs.get("distance_fn", sed)
 
     @property
     def omega_matrix(self):
-        return self.omega_layer.weight.detach().cpu()
+        return self.backbone.weight.detach().cpu()
 
     @property
     def lambda_matrix(self):
-        omega = self.omega_layer.weight # (latent_dim, input_dim)
+        omega = self.backbone.weight # (latent_dim, input_dim)
         lam = omega.T @ omega
         return lam.detach().cpu()
 
@@ -243,38 +170,13 @@ class GMLVQ(GLVQ):
     def forward(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
-        latent_x = self.omega_layer(x)
-        latent_protos = self.omega_layer(protos)
-        dis = squared_euclidean_distance(latent_x, latent_protos)
+        latent_x = self.backbone(x)
+        latent_protos = self.backbone(protos)
+        dis = self.distance_fn(latent_x, latent_protos)
         return dis
 
-    def predict_latent(self, x):
-        """Predict `x` assuming it is already embedded in the latent space.
-
-        Only the prototypes are embedded in the latent space using the
-        backbone.
-
-        """
-        # model.eval() # ?!
-        with torch.no_grad():
-            protos, plabels = self.proto_layer()
-            latent_protos = self.omega_layer(protos)
-            d = squared_euclidean_distance(x, latent_protos)
-            y_pred = wtac(d, plabels)
-            return y_pred
-
-
-class ImageGMLVQ(GMLVQ, PrototypeImageModel):
-    """GMLVQ for training on image data.
-
-    GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
-    after updates.
-
-    """
-    pass
-
-
-class LVQMLN(GLVQ):
+
+class LVQMLN(SiamesePrototypeModel, GLVQ):
     """Learning Vector Quantization Multi-Layer Network.
 
     GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
@@ -283,27 +185,50 @@ class LVQMLN(GLVQ):
     rather in the embedding space.
 
""" - def __init__(self, - hparams, - backbone_module=torch.nn.Identity, - backbone_params={}, - **kwargs): + def __init__(self, hparams, backbone=torch.nn.Identity(), **kwargs): super().__init__(hparams, **kwargs) - self.backbone = backbone_module(**backbone_params) - with torch.no_grad(): - protos = self.backbone(self.proto_layer()[0]) - self.proto_layer.load_state_dict({"_components": protos}, strict=False) + self.backbone = backbone + + self.distance_fn = kwargs.get("distance_fn", sed) def forward(self, x): latent_protos, _ = self.proto_layer() latent_x = self.backbone(x) - dis = euclidean_distance(latent_x, latent_protos) + dis = self.distance_fn(latent_x, latent_protos) return dis - def predict_latent(self, x): - """Predict `x` assuming it is already embedded in the latent space.""" - with torch.no_grad(): - latent_protos, plabels = self.proto_layer() - d = euclidean_distance(x, latent_protos) - y_pred = wtac(d, plabels) - return y_pred + +class LVQ1(GLVQ): + """Learning Vector Quantization 1.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.loss = lvq1_loss + self.optimizer = torch.optim.SGD + + +class LVQ21(GLVQ): + """Learning Vector Quantization 2.1.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.loss = lvq21_loss + self.optimizer = torch.optim.SGD + + +class ImageGLVQ(PrototypeImageModel, GLVQ): + """GLVQ for training on image data. + + GLVQ model that constrains the prototypes to the range [0, 1] by clamping + after updates. + + """ + pass + + +class ImageGMLVQ(PrototypeImageModel, GMLVQ): + """GMLVQ for training on image data. + + GMLVQ model that constrains the prototypes to the range [0, 1] by clamping + after updates. + + """ + pass