Cleanup models
Siamese architectures no longer accept a `backbone_module`. They have to be initialized with an pre-initialized backbone object instead. This is so that the visualization callbacks could use the very same object for visualization purposes. Also, there's no longer a dependent copy of the backbone. It is managed simply with `requires_grad` instead.
This commit is contained in:
parent
7a87636ad7
commit
81346785bd
@ -1,5 +1,6 @@
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from prototorch.functions.competitions import wtac
|
||||||
from torch.optim.lr_scheduler import ExponentialLR
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
|
|
||||||
@ -29,3 +30,33 @@ class AbstractPrototypeModel(pl.LightningModule):
|
|||||||
class PrototypeImageModel(pl.LightningModule):
|
class PrototypeImageModel(pl.LightningModule):
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class SiamesePrototypeModel(pl.LightningModule):
|
||||||
|
def configure_optimizers(self):
|
||||||
|
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||||
|
lr=self.hparams.proto_lr)
|
||||||
|
if list(self.backbone.parameters()):
|
||||||
|
# only add an optimizer is the backbone has trainable parameters
|
||||||
|
# otherwise, the next line fails
|
||||||
|
bb_opt = self.optimizer(self.backbone.parameters(),
|
||||||
|
lr=self.hparams.bb_lr)
|
||||||
|
return proto_opt, bb_opt
|
||||||
|
else:
|
||||||
|
return proto_opt
|
||||||
|
|
||||||
|
def predict_latent(self, x, map_protos=True):
|
||||||
|
"""Predict `x` assuming it is already embedded in the latent space.
|
||||||
|
|
||||||
|
Only the prototypes are embedded in the latent space using the
|
||||||
|
backbone.
|
||||||
|
|
||||||
|
"""
|
||||||
|
# model.eval() # ?!
|
||||||
|
with torch.no_grad():
|
||||||
|
protos, plabels = self.proto_layer()
|
||||||
|
if map_protos:
|
||||||
|
protos = self.backbone(protos)
|
||||||
|
d = self.distance_fn(x, protos)
|
||||||
|
y_pred = wtac(d, plabels)
|
||||||
|
return y_pred
|
||||||
|
@ -4,11 +4,12 @@ from prototorch.components import LabeledComponents
|
|||||||
from prototorch.functions.activations import get_activation
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
||||||
squared_euclidean_distance)
|
sed)
|
||||||
from prototorch.functions.helper import get_flat
|
from prototorch.functions.helper import get_flat
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
|
||||||
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
|
||||||
|
SiamesePrototypeModel)
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(AbstractPrototypeModel):
|
class GLVQ(AbstractPrototypeModel):
|
||||||
@ -25,11 +26,11 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
self.distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
prototype_initializer = kwargs.get("prototype_initializer", None)
|
prototype_initializer = kwargs.get("prototype_initializer", None)
|
||||||
|
|
||||||
# Default Values
|
# Default Values
|
||||||
self.hparams.setdefault("distance", euclidean_distance)
|
|
||||||
self.hparams.setdefault("transfer_function", "identity")
|
self.hparams.setdefault("transfer_function", "identity")
|
||||||
self.hparams.setdefault("transfer_beta", 10.0)
|
self.hparams.setdefault("transfer_beta", 10.0)
|
||||||
|
|
||||||
@ -48,7 +49,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
dis = self.hparams.distance(x, protos)
|
dis = self.distance_fn(x, protos)
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
@ -87,33 +88,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class LVQ1(GLVQ):
|
class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
|
||||||
"""Learning Vector Quantization 1."""
|
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
|
||||||
self.loss = lvq1_loss
|
|
||||||
self.optimizer = torch.optim.SGD
|
|
||||||
|
|
||||||
|
|
||||||
class LVQ21(GLVQ):
|
|
||||||
"""Learning Vector Quantization 2.1."""
|
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
|
||||||
self.loss = lvq21_loss
|
|
||||||
self.optimizer = torch.optim.SGD
|
|
||||||
|
|
||||||
|
|
||||||
class ImageGLVQ(GLVQ, PrototypeImageModel):
|
|
||||||
"""GLVQ for training on image data.
|
|
||||||
|
|
||||||
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
|
||||||
after updates.
|
|
||||||
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class SiameseGLVQ(GLVQ):
|
|
||||||
"""GLVQ in a Siamese setting.
|
"""GLVQ in a Siamese setting.
|
||||||
|
|
||||||
GLVQ model that applies an arbitrary transformation on the inputs and the
|
GLVQ model that applies an arbitrary transformation on the inputs and the
|
||||||
@ -123,110 +98,62 @@ class SiameseGLVQ(GLVQ):
|
|||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
hparams,
|
hparams,
|
||||||
backbone_module=torch.nn.Identity,
|
backbone=torch.nn.Identity(),
|
||||||
backbone_params={},
|
both_path_gradients=False,
|
||||||
sync=True,
|
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.backbone = backbone_module(**backbone_params)
|
self.backbone = backbone
|
||||||
self.backbone_dependent = backbone_module(
|
self.both_path_gradients = both_path_gradients
|
||||||
**backbone_params).requires_grad_(False)
|
self.distance_fn = kwargs.get("distance_fn", sed)
|
||||||
self.sync = sync
|
|
||||||
|
|
||||||
def sync_backbones(self):
|
|
||||||
master_state = self.backbone.state_dict()
|
|
||||||
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
|
||||||
lr=self.hparams.proto_lr)
|
|
||||||
if list(self.backbone.parameters()):
|
|
||||||
# only add an optimizer is the backbone has trainable parameters
|
|
||||||
# otherwise, the next line fails
|
|
||||||
bb_opt = self.optimizer(self.backbone.parameters(),
|
|
||||||
lr=self.hparams.bb_lr)
|
|
||||||
return proto_opt, bb_opt
|
|
||||||
else:
|
|
||||||
return proto_opt
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
if self.sync:
|
|
||||||
self.sync_backbones()
|
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
latent_protos = self.backbone_dependent(protos)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
dis = euclidean_distance(latent_x, latent_protos)
|
latent_protos = self.backbone(protos)
|
||||||
|
self.backbone.requires_grad_(True)
|
||||||
|
dis = self.distance_fn(latent_x, latent_protos)
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
def predict_latent(self, x):
|
|
||||||
"""Predict `x` assuming it is already embedded in the latent space.
|
|
||||||
|
|
||||||
Only the prototypes are embedded in the latent space using the
|
class GRLVQ(SiamesePrototypeModel, GLVQ):
|
||||||
backbone.
|
|
||||||
|
|
||||||
"""
|
|
||||||
# model.eval() # ?!
|
|
||||||
with torch.no_grad():
|
|
||||||
protos, plabels = self.proto_layer()
|
|
||||||
latent_protos = self.backbone_dependent(protos)
|
|
||||||
d = euclidean_distance(x, latent_protos)
|
|
||||||
y_pred = wtac(d, plabels)
|
|
||||||
return y_pred
|
|
||||||
|
|
||||||
|
|
||||||
class GRLVQ(GLVQ):
|
|
||||||
"""Generalized Relevance Learning Vector Quantization."""
|
"""Generalized Relevance Learning Vector Quantization."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.relevances = torch.nn.parameter.Parameter(
|
self.relevances = torch.nn.parameter.Parameter(
|
||||||
torch.ones(self.hparams.input_dim))
|
torch.ones(self.hparams.input_dim))
|
||||||
|
self.distance_fn = kwargs.get("distance_fn", sed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def relevance_profile(self):
|
||||||
|
return self.relevances.detach().cpu()
|
||||||
|
|
||||||
|
def backbone(self, x):
|
||||||
|
"""Namespace hook for the visualization callbacks to work."""
|
||||||
|
return x @ torch.diag(self.relevances)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
dis = omega_distance(x, protos, torch.diag(self.relevances))
|
dis = omega_distance(x, protos, torch.diag(self.relevances))
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
def backbone(self, x):
|
|
||||||
return x @ torch.diag(self.relevances)
|
|
||||||
|
|
||||||
@property
|
class GMLVQ(SiamesePrototypeModel, GLVQ):
|
||||||
def relevance_profile(self):
|
|
||||||
return self.relevances.detach().cpu()
|
|
||||||
|
|
||||||
def predict_latent(self, x):
|
|
||||||
"""Predict `x` assuming it is already embedded in the latent space.
|
|
||||||
|
|
||||||
Only the prototypes are embedded in the latent space using the
|
|
||||||
backbone.
|
|
||||||
|
|
||||||
"""
|
|
||||||
# model.eval() # ?!
|
|
||||||
with torch.no_grad():
|
|
||||||
protos, plabels = self.proto_layer()
|
|
||||||
latent_protos = protos @ torch.diag(self.relevances)
|
|
||||||
d = squared_euclidean_distance(x, latent_protos)
|
|
||||||
y_pred = wtac(d, plabels)
|
|
||||||
return y_pred
|
|
||||||
|
|
||||||
|
|
||||||
class GMLVQ(GLVQ):
|
|
||||||
"""Generalized Matrix Learning Vector Quantization."""
|
"""Generalized Matrix Learning Vector Quantization."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.omega_layer = torch.nn.Linear(self.hparams.input_dim,
|
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
||||||
self.hparams.latent_dim,
|
self.hparams.latent_dim,
|
||||||
bias=False)
|
bias=False)
|
||||||
|
self.distance_fn = kwargs.get("distance_fn", sed)
|
||||||
# Namespace hook for the visualization callbacks to work
|
|
||||||
self.backbone = self.omega_layer
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
||||||
return self.omega_layer.weight.detach().cpu()
|
return self.backbone.weight.detach().cpu()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def lambda_matrix(self):
|
def lambda_matrix(self):
|
||||||
omega = self.omega_layer.weight # (latent_dim, input_dim)
|
omega = self.backbone.weight # (latent_dim, input_dim)
|
||||||
lam = omega.T @ omega
|
lam = omega.T @ omega
|
||||||
return lam.detach().cpu()
|
return lam.detach().cpu()
|
||||||
|
|
||||||
@ -243,38 +170,13 @@ class GMLVQ(GLVQ):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = get_flat(x, protos)
|
x, protos = get_flat(x, protos)
|
||||||
latent_x = self.omega_layer(x)
|
latent_x = self.backbone(x)
|
||||||
latent_protos = self.omega_layer(protos)
|
latent_protos = self.backbone(protos)
|
||||||
dis = squared_euclidean_distance(latent_x, latent_protos)
|
dis = self.distance_fn(latent_x, latent_protos)
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
def predict_latent(self, x):
|
|
||||||
"""Predict `x` assuming it is already embedded in the latent space.
|
|
||||||
|
|
||||||
Only the prototypes are embedded in the latent space using the
|
class LVQMLN(SiamesePrototypeModel, GLVQ):
|
||||||
backbone.
|
|
||||||
|
|
||||||
"""
|
|
||||||
# model.eval() # ?!
|
|
||||||
with torch.no_grad():
|
|
||||||
protos, plabels = self.proto_layer()
|
|
||||||
latent_protos = self.omega_layer(protos)
|
|
||||||
d = squared_euclidean_distance(x, latent_protos)
|
|
||||||
y_pred = wtac(d, plabels)
|
|
||||||
return y_pred
|
|
||||||
|
|
||||||
|
|
||||||
class ImageGMLVQ(GMLVQ, PrototypeImageModel):
|
|
||||||
"""GMLVQ for training on image data.
|
|
||||||
|
|
||||||
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
|
||||||
after updates.
|
|
||||||
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LVQMLN(GLVQ):
|
|
||||||
"""Learning Vector Quantization Multi-Layer Network.
|
"""Learning Vector Quantization Multi-Layer Network.
|
||||||
|
|
||||||
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
|
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
|
||||||
@ -283,27 +185,50 @@ class LVQMLN(GLVQ):
|
|||||||
rather in the embedding space.
|
rather in the embedding space.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self, hparams, backbone=torch.nn.Identity(), **kwargs):
|
||||||
hparams,
|
|
||||||
backbone_module=torch.nn.Identity,
|
|
||||||
backbone_params={},
|
|
||||||
**kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.backbone = backbone_module(**backbone_params)
|
self.backbone = backbone
|
||||||
with torch.no_grad():
|
|
||||||
protos = self.backbone(self.proto_layer()[0])
|
self.distance_fn = kwargs.get("distance_fn", sed)
|
||||||
self.proto_layer.load_state_dict({"_components": protos}, strict=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
latent_protos, _ = self.proto_layer()
|
latent_protos, _ = self.proto_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
dis = euclidean_distance(latent_x, latent_protos)
|
dis = self.distance_fn(latent_x, latent_protos)
|
||||||
return dis
|
return dis
|
||||||
|
|
||||||
def predict_latent(self, x):
|
|
||||||
"""Predict `x` assuming it is already embedded in the latent space."""
|
class LVQ1(GLVQ):
|
||||||
with torch.no_grad():
|
"""Learning Vector Quantization 1."""
|
||||||
latent_protos, plabels = self.proto_layer()
|
def __init__(self, hparams, **kwargs):
|
||||||
d = euclidean_distance(x, latent_protos)
|
super().__init__(hparams, **kwargs)
|
||||||
y_pred = wtac(d, plabels)
|
self.loss = lvq1_loss
|
||||||
return y_pred
|
self.optimizer = torch.optim.SGD
|
||||||
|
|
||||||
|
|
||||||
|
class LVQ21(GLVQ):
|
||||||
|
"""Learning Vector Quantization 2.1."""
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
super().__init__(hparams, **kwargs)
|
||||||
|
self.loss = lvq21_loss
|
||||||
|
self.optimizer = torch.optim.SGD
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGLVQ(PrototypeImageModel, GLVQ):
|
||||||
|
"""GLVQ for training on image data.
|
||||||
|
|
||||||
|
GLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||||
|
after updates.
|
||||||
|
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGMLVQ(PrototypeImageModel, GMLVQ):
|
||||||
|
"""GMLVQ for training on image data.
|
||||||
|
|
||||||
|
GMLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||||
|
after updates.
|
||||||
|
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user