Add Local-Matrix LVQ
Also remove the use of `self.distance_fn` in favor of `self.distance_layer`.
This commit is contained in:
parent
5ec2dd47cd
commit
757f4e980d
@ -1,10 +1,23 @@
|
|||||||
|
"""`models` plugin for the `prototorch` package."""
|
||||||
|
|
||||||
from importlib.metadata import PackageNotFoundError, version
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from .probabilistic import LikelihoodRatioLVQ, RSLVQ
|
|
||||||
from .cbc import CBC, ImageCBC
|
from .cbc import CBC, ImageCBC
|
||||||
from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
|
from .glvq import (
|
||||||
ImageGMLVQ, SiameseGLVQ)
|
GLVQ,
|
||||||
|
GLVQ1,
|
||||||
|
GLVQ21,
|
||||||
|
GMLVQ,
|
||||||
|
GRLVQ,
|
||||||
|
LGMLVQ,
|
||||||
|
LVQMLN,
|
||||||
|
ImageGLVQ,
|
||||||
|
ImageGMLVQ,
|
||||||
|
SiameseGLVQ,
|
||||||
|
SiameseGMLVQ,
|
||||||
|
)
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
|
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
|
||||||
from .unsupervised import KNN, NeuralGas
|
from .unsupervised import KNN, NeuralGas
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
|
@ -4,12 +4,17 @@ import torch
|
|||||||
import torchmetrics
|
import torchmetrics
|
||||||
from prototorch.components import LabeledComponents
|
from prototorch.components import LabeledComponents
|
||||||
from prototorch.functions.activations import get_activation
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.competitions import stratified_min, wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import (euclidean_distance, omega_distance,
|
from prototorch.functions.distances import (
|
||||||
sed)
|
euclidean_distance,
|
||||||
|
lomega_distance,
|
||||||
|
omega_distance,
|
||||||
|
squared_euclidean_distance,
|
||||||
|
)
|
||||||
from prototorch.functions.helper import get_flat
|
from prototorch.functions.helper import get_flat
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
from prototorch.modules import LambdaLayer
|
from prototorch.modules import LambdaLayer
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
||||||
|
|
||||||
@ -17,17 +22,19 @@ from .abstract import AbstractPrototypeModel, PrototypeImageModel
|
|||||||
class GLVQ(AbstractPrototypeModel):
|
class GLVQ(AbstractPrototypeModel):
|
||||||
"""Generalized Learning Vector Quantization."""
|
"""Generalized Learning Vector Quantization."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
self.save_hyperparameters(hparams) # Default Values
|
self.save_hyperparameters(hparams)
|
||||||
|
|
||||||
|
# Defaults
|
||||||
self.hparams.setdefault("transfer_fn", "identity")
|
self.hparams.setdefault("transfer_fn", "identity")
|
||||||
self.hparams.setdefault("transfer_beta", 10.0)
|
self.hparams.setdefault("transfer_beta", 10.0)
|
||||||
self.hparams.setdefault("lr", 0.01)
|
self.hparams.setdefault("lr", 0.01)
|
||||||
|
|
||||||
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
distance_fn = kwargs.get("distance_fn", euclidean_distance)
|
||||||
tranfer_fn = get_activation(self.hparams.transfer_fn)
|
transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||||
|
|
||||||
# Layers
|
# Layers
|
||||||
self.proto_layer = LabeledComponents(
|
self.proto_layer = LabeledComponents(
|
||||||
@ -35,7 +42,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
initializer=self.prototype_initializer(**kwargs))
|
initializer=self.prototype_initializer(**kwargs))
|
||||||
|
|
||||||
self.distance_layer = LambdaLayer(distance_fn)
|
self.distance_layer = LambdaLayer(distance_fn)
|
||||||
self.transfer_layer = LambdaLayer(tranfer_fn)
|
self.transfer_layer = LambdaLayer(transfer_fn)
|
||||||
self.loss = LambdaLayer(glvq_loss)
|
self.loss = LambdaLayer(glvq_loss)
|
||||||
|
|
||||||
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
self.optimizer = kwargs.get("optimizer", torch.optim.Adam)
|
||||||
@ -123,8 +130,12 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
# def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||||
# pass
|
# pass
|
||||||
|
|
||||||
def increase_prototypes(self, initializer, distribution):
|
def add_prototypes(self, initializer, distribution):
|
||||||
self.proto_layer.increase_components(initializer, distribution)
|
self.proto_layer.add_components(initializer, distribution)
|
||||||
|
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
|
||||||
|
|
||||||
|
def remove_prototypes(self, indices):
|
||||||
|
self.proto_layer.remove_components(indices)
|
||||||
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
|
self.trainer.accelerator_backend.setup_optimizers(self.trainer)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -145,10 +156,10 @@ class SiameseGLVQ(GLVQ):
|
|||||||
backbone=torch.nn.Identity(),
|
backbone=torch.nn.Identity(),
|
||||||
both_path_gradients=False,
|
both_path_gradients=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
|
||||||
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.both_path_gradients = both_path_gradients
|
self.both_path_gradients = both_path_gradients
|
||||||
self.distance_fn = kwargs.get("distance_fn", sed)
|
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
proto_opt = self.optimizer(self.proto_layer.parameters(),
|
||||||
@ -168,7 +179,7 @@ class SiameseGLVQ(GLVQ):
|
|||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
self.backbone.requires_grad_(True)
|
self.backbone.requires_grad_(True)
|
||||||
distances = self.distance_fn(latent_x, latent_protos)
|
distances = self.distance_layer(latent_x, latent_protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
def predict_latent(self, x, map_protos=True):
|
def predict_latent(self, x, map_protos=True):
|
||||||
@ -183,39 +194,44 @@ class SiameseGLVQ(GLVQ):
|
|||||||
protos, plabels = self.proto_layer()
|
protos, plabels = self.proto_layer()
|
||||||
if map_protos:
|
if map_protos:
|
||||||
protos = self.backbone(protos)
|
protos = self.backbone(protos)
|
||||||
d = self.distance_fn(x, protos)
|
d = self.distance_layer(x, protos)
|
||||||
y_pred = wtac(d, plabels)
|
y_pred = wtac(d, plabels)
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class GRLVQ(SiameseGLVQ):
|
class GRLVQ(SiameseGLVQ):
|
||||||
"""Generalized Relevance Learning Vector Quantization."""
|
"""Generalized Relevance Learning Vector Quantization.
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
|
||||||
self.relevances = torch.nn.parameter.Parameter(
|
|
||||||
torch.ones(self.hparams.input_dim))
|
|
||||||
|
|
||||||
# Overwrite backbone
|
TODO Make a RelevanceLayer. `bb_lr` is ignored otherwise.
|
||||||
self.backbone = self._backbone
|
"""
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
||||||
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
relevances = torch.ones(self.hparams.input_dim, device=self.device)
|
||||||
|
self.register_parameter("_relevances", Parameter(relevances))
|
||||||
|
# Override the backbone.
|
||||||
|
self.backbone = LambdaLayer(lambda x: x @ torch.diag(self.relevances),
|
||||||
|
name="relevances")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def relevance_profile(self):
|
def relevance_profile(self):
|
||||||
return self.relevances.detach().cpu()
|
return self.relevances.detach().cpu()
|
||||||
|
|
||||||
def _backbone(self, x):
|
|
||||||
"""Namespace hook for the visualization callbacks to work."""
|
|
||||||
return x @ torch.diag(self.relevances)
|
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
distances = omega_distance(x, protos, torch.diag(self.relevances))
|
distances = self.distance_layer(x, protos, torch.diag(self.relevances))
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
class GMLVQ(SiameseGLVQ):
|
class SiameseGMLVQ(SiameseGLVQ):
|
||||||
"""Generalized Matrix Learning Vector Quantization."""
|
"""Generalized Matrix Learning Vector Quantization.
|
||||||
|
|
||||||
|
Implemented as a Siamese network with a linear transformation backbone.
|
||||||
|
|
||||||
|
"""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
# Override the backbone.
|
||||||
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
||||||
self.hparams.latent_dim,
|
self.hparams.latent_dim,
|
||||||
bias=False)
|
bias=False)
|
||||||
@ -230,16 +246,6 @@ class GMLVQ(SiameseGLVQ):
|
|||||||
lam = omega.T @ omega
|
lam = omega.T @ omega
|
||||||
return lam.detach().cpu()
|
return lam.detach().cpu()
|
||||||
|
|
||||||
def show_lambda(self):
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
title = "Lambda matrix"
|
|
||||||
plt.figure(title)
|
|
||||||
plt.title(title)
|
|
||||||
plt.imshow(self.lambda_matrix, cmap="gray")
|
|
||||||
plt.axis("off")
|
|
||||||
plt.colorbar()
|
|
||||||
plt.show(block=True)
|
|
||||||
|
|
||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = get_flat(x, protos)
|
x, protos = get_flat(x, protos)
|
||||||
@ -247,7 +253,7 @@ class GMLVQ(SiameseGLVQ):
|
|||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
self.backbone.requires_grad_(True)
|
self.backbone.requires_grad_(True)
|
||||||
distances = self.distance_fn(latent_x, latent_protos)
|
distances = self.distance_layer(latent_x, latent_protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
@ -263,24 +269,47 @@ class LVQMLN(SiameseGLVQ):
|
|||||||
def _forward(self, x):
|
def _forward(self, x):
|
||||||
latent_protos, _ = self.proto_layer()
|
latent_protos, _ = self.proto_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
distances = self.distance_fn(latent_x, latent_protos)
|
distances = self.distance_layer(latent_x, latent_protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
|
||||||
class CELVQ(GLVQ):
|
class GMLVQ(GLVQ):
|
||||||
"""Cross-Entropy Learning Vector Quantization."""
|
"""Generalized Matrix Learning Vector Quantization.
|
||||||
def __init__(self, hparams, **kwargs):
|
|
||||||
super().__init__(hparams, **kwargs)
|
|
||||||
self.loss = torch.nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
Implemented as a regular GLVQ network that simply uses a different distance
|
||||||
x, y = batch
|
function.
|
||||||
out = self._forward(x) # [None, num_protos]
|
|
||||||
plabels = self.proto_layer.component_labels
|
"""
|
||||||
probs = -1.0 * stratified_min(out, plabels) # [None, num_classes]
|
def __init__(self, hparams, **kwargs):
|
||||||
batch_loss = self.loss(out, y.long())
|
distance_fn = kwargs.pop("distance_fn", omega_distance)
|
||||||
loss = batch_loss.sum(dim=0)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
return out, loss
|
omega = torch.randn(self.hparams.input_dim,
|
||||||
|
self.hparams.latent_dim,
|
||||||
|
device=self.device)
|
||||||
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
|
protos, _ = self.proto_layer()
|
||||||
|
distances = self.distance_layer(x, protos, self._omega)
|
||||||
|
return distances
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
||||||
|
|
||||||
|
|
||||||
|
class LGMLVQ(GMLVQ):
|
||||||
|
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
distance_fn = kwargs.pop("distance_fn", lomega_distance)
|
||||||
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
# Re-register `_omega` to override the one from the super class.
|
||||||
|
omega = torch.randn(
|
||||||
|
self.num_prototypes,
|
||||||
|
self.hparams.input_dim,
|
||||||
|
self.hparams.latent_dim,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
|
|
||||||
class GLVQ1(GLVQ):
|
class GLVQ1(GLVQ):
|
||||||
|
Loading…
Reference in New Issue
Block a user