[WIP] Update CBC implementation to use SiameseGLVQ
This commit is contained in:
parent
49f9a12b5f
commit
88a34a06ef
@ -6,13 +6,10 @@ import torch
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Dataset
|
# Dataset
|
||||||
from sklearn.datasets import load_iris
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
x_train, y_train = load_iris(return_X_y=True)
|
|
||||||
x_train = x_train[:, [0, 2]]
|
|
||||||
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
|
|
||||||
|
|
||||||
# Reproducibility
|
# Reproducibility
|
||||||
pl.utilities.seed.seed_everything(seed=2)
|
pl.utilities.seed.seed_everything(seed=3)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
@ -21,18 +18,19 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
input_dim=x_train.shape[1],
|
distribution=[3, 2, 2],
|
||||||
nclasses=3,
|
proto_lr=0.01,
|
||||||
num_components=5,
|
bb_lr=0.01,
|
||||||
component_initializer=pt.components.SSI(train_ds, noise=0.01),
|
|
||||||
lr=0.01,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.CBC(hparams)
|
model = pt.models.CBC(
|
||||||
|
hparams,
|
||||||
|
prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
|
||||||
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
dvis = pt.models.VisCBC2D(data=(x_train, y_train),
|
dvis = pt.models.VisCBC2D(data=train_ds,
|
||||||
title="CBC Iris Example",
|
title="CBC Iris Example",
|
||||||
resolution=300,
|
resolution=300,
|
||||||
axis_off=True)
|
axis_off=True)
|
||||||
|
@ -5,6 +5,10 @@ from prototorch.components.components import Components
|
|||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.functions.similarities import cosine_similarity
|
from prototorch.functions.similarities import cosine_similarity
|
||||||
|
|
||||||
|
from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
|
||||||
|
SiamesePrototypeModel)
|
||||||
|
from .glvq import SiameseGLVQ
|
||||||
|
|
||||||
|
|
||||||
def rescaled_cosine_similarity(x, y):
|
def rescaled_cosine_similarity(x, y):
|
||||||
"""Cosine Similarity rescaled to [0, 1]."""
|
"""Cosine Similarity rescaled to [0, 1]."""
|
||||||
@ -16,9 +20,9 @@ def shift_activation(x):
|
|||||||
return (x + 1.0) / 2.0
|
return (x + 1.0) / 2.0
|
||||||
|
|
||||||
|
|
||||||
def euclidean_similarity(x, y):
|
def euclidean_similarity(x, y, beta=3):
|
||||||
d = euclidean_distance(x, y)
|
d = euclidean_distance(x, y)
|
||||||
return torch.exp(-d * 3)
|
return torch.exp(-d * beta)
|
||||||
|
|
||||||
|
|
||||||
class CosineSimilarity(torch.nn.Module):
|
class CosineSimilarity(torch.nn.Module):
|
||||||
@ -55,11 +59,12 @@ class MarginLoss(torch.nn.modules.loss._Loss):
|
|||||||
|
|
||||||
|
|
||||||
class ReasoningLayer(torch.nn.Module):
|
class ReasoningLayer(torch.nn.Module):
|
||||||
def __init__(self, n_components, n_classes, n_replicas=1):
|
def __init__(self, num_components, num_classes, n_replicas=1):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.n_replicas = n_replicas
|
self.n_replicas = n_replicas
|
||||||
self.n_classes = n_classes
|
self.num_classes = num_classes
|
||||||
probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
|
probabilities_init = torch.zeros(2, 1, num_components,
|
||||||
|
self.num_classes)
|
||||||
probabilities_init.uniform_(0.4, 0.6)
|
probabilities_init.uniform_(0.4, 0.6)
|
||||||
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
|
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
|
||||||
|
|
||||||
@ -81,73 +86,59 @@ class ReasoningLayer(torch.nn.Module):
|
|||||||
return probs
|
return probs
|
||||||
|
|
||||||
|
|
||||||
class CBC(pl.LightningModule):
|
class CBC(SiameseGLVQ):
|
||||||
"""Classification-By-Components."""
|
"""Classification-By-Components."""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
hparams,
|
hparams,
|
||||||
margin=0.1,
|
margin=0.1,
|
||||||
backbone_class=torch.nn.Identity,
|
|
||||||
similarity=euclidean_similarity,
|
similarity=euclidean_similarity,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__()
|
super().__init__(hparams, **kwargs)
|
||||||
self.save_hyperparameters(hparams)
|
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.component_layer = Components(self.hparams.num_components,
|
self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
|
||||||
self.hparams.component_initializer)
|
num_components = self.components.shape[0]
|
||||||
# self.similarity = CosineSimilarity()
|
self.reasoning_layer = ReasoningLayer(num_components=num_components,
|
||||||
self.similarity = similarity
|
num_classes=self.num_classes)
|
||||||
self.backbone = backbone_class()
|
self.component_layer = self.proto_layer
|
||||||
self.backbone_dependent = backbone_class().requires_grad_(False)
|
|
||||||
n_components = self.components.shape[0]
|
|
||||||
self.reasoning_layer = ReasoningLayer(n_components=n_components,
|
|
||||||
n_classes=self.hparams.nclasses)
|
|
||||||
self.train_acc = torchmetrics.Accuracy()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def components(self):
|
def components(self):
|
||||||
return self.component_layer.components.detach().cpu()
|
return self.prototypes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reasonings(self):
|
def reasonings(self):
|
||||||
return self.reasoning_layer.reasonings.cpu()
|
return self.reasoning_layer.reasonings.cpu()
|
||||||
|
|
||||||
def configure_optimizers(self):
|
|
||||||
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
|
|
||||||
return optimizer
|
|
||||||
|
|
||||||
def sync_backbones(self):
|
|
||||||
master_state = self.backbone.state_dict()
|
|
||||||
self.backbone_dependent.load_state_dict(master_state, strict=True)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
self.sync_backbones()
|
components, _ = self.component_layer()
|
||||||
protos = self.component_layer()
|
|
||||||
|
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
latent_protos = self.backbone_dependent(protos)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
|
latent_components = self.backbone(components)
|
||||||
detections = self.similarity(latent_x, latent_protos)
|
self.backbone.requires_grad_(True)
|
||||||
|
detections = self.similarity_fn(latent_x, latent_components)
|
||||||
probs = self.reasoning_layer(detections)
|
probs = self.reasoning_layer(detections)
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = train_batch
|
x, y = batch
|
||||||
x = x.view(x.size(0), -1)
|
# x = x.view(x.size(0), -1)
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
nclasses = self.reasoning_layer.n_classes
|
nclasses = self.reasoning_layer.num_classes
|
||||||
y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
|
y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
|
||||||
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
|
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
|
||||||
self.log("train_loss", loss)
|
return y_pred, loss
|
||||||
self.train_acc(y_pred, y_true)
|
|
||||||
self.log(
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
"acc",
|
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
self.train_acc,
|
preds = torch.argmax(y_pred, dim=1)
|
||||||
on_step=False,
|
self.acc_metric(preds.int(), batch[1].int())
|
||||||
on_epoch=True,
|
self.log("train_acc",
|
||||||
prog_bar=True,
|
self.acc_metric,
|
||||||
logger=True,
|
on_step=False,
|
||||||
)
|
on_epoch=True,
|
||||||
return loss
|
prog_bar=True,
|
||||||
|
logger=True)
|
||||||
|
return train_loss
|
||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -49,11 +49,21 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
return self.proto_layer.component_labels.detach().cpu()
|
return self.proto_layer.component_labels.detach().cpu()
|
||||||
|
|
||||||
def forward(self, x):
|
@property
|
||||||
|
def num_classes(self):
|
||||||
|
return len(self.proto_layer.distribution)
|
||||||
|
|
||||||
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
distances = self.distance_fn(x, protos)
|
distances = self.distance_fn(x, protos)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
distances = self._forward(x)
|
||||||
|
y_pred = self.predict_from_distances(distances)
|
||||||
|
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
|
||||||
|
return y_pred
|
||||||
|
|
||||||
def predict_from_distances(self, distances):
|
def predict_from_distances(self, distances):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
@ -62,7 +72,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
distances = self(x)
|
distances = self._forward(x)
|
||||||
y_pred = self.predict_from_distances(distances)
|
y_pred = self.predict_from_distances(distances)
|
||||||
return y_pred
|
return y_pred
|
||||||
|
|
||||||
@ -80,7 +90,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
def shared_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
x, y = batch
|
x, y = batch
|
||||||
out = self(x)
|
out = self._forward(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
mu = self.loss(out, y, prototype_labels=plabels)
|
mu = self.loss(out, y, prototype_labels=plabels)
|
||||||
batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
|
batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
|
||||||
@ -89,6 +99,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
|
||||||
|
self.log("train_loss", train_loss)
|
||||||
self.log_acc(out, batch[-1], tag="train_acc")
|
self.log_acc(out, batch[-1], tag="train_acc")
|
||||||
return train_loss
|
return train_loss
|
||||||
|
|
||||||
@ -137,23 +148,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
|
|||||||
self.both_path_gradients = both_path_gradients
|
self.both_path_gradients = both_path_gradients
|
||||||
self.distance_fn = kwargs.get("distance_fn", sed)
|
self.distance_fn = kwargs.get("distance_fn", sed)
|
||||||
|
|
||||||
def forward(self, x):
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
self.backbone.requires_grad_(self.both_path_gradients)
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
self.backbone.requires_grad_(True)
|
self.backbone.requires_grad_(True)
|
||||||
dis = self.distance_fn(latent_x, latent_protos)
|
distances = self.distance_fn(latent_x, latent_protos)
|
||||||
return dis
|
return distances
|
||||||
|
|
||||||
|
|
||||||
class GRLVQ(SiamesePrototypeModel, GLVQ):
|
class GRLVQ(SiameseGLVQ):
|
||||||
"""Generalized Relevance Learning Vector Quantization."""
|
"""Generalized Relevance Learning Vector Quantization."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.relevances = torch.nn.parameter.Parameter(
|
self.relevances = torch.nn.parameter.Parameter(
|
||||||
torch.ones(self.hparams.input_dim))
|
torch.ones(self.hparams.input_dim))
|
||||||
self.distance_fn = kwargs.get("distance_fn", sed)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def relevance_profile(self):
|
def relevance_profile(self):
|
||||||
@ -163,20 +173,19 @@ class GRLVQ(SiamesePrototypeModel, GLVQ):
|
|||||||
"""Namespace hook for the visualization callbacks to work."""
|
"""Namespace hook for the visualization callbacks to work."""
|
||||||
return x @ torch.diag(self.relevances)
|
return x @ torch.diag(self.relevances)
|
||||||
|
|
||||||
def forward(self, x):
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
dis = omega_distance(x, protos, torch.diag(self.relevances))
|
distances = omega_distance(x, protos, torch.diag(self.relevances))
|
||||||
return dis
|
return distances
|
||||||
|
|
||||||
|
|
||||||
class GMLVQ(SiamesePrototypeModel, GLVQ):
|
class GMLVQ(SiameseGLVQ):
|
||||||
"""Generalized Matrix Learning Vector Quantization."""
|
"""Generalized Matrix Learning Vector Quantization."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
self.backbone = torch.nn.Linear(self.hparams.input_dim,
|
||||||
self.hparams.latent_dim,
|
self.hparams.latent_dim,
|
||||||
bias=False)
|
bias=False)
|
||||||
self.distance_fn = kwargs.get("distance_fn", sed)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
||||||
@ -198,16 +207,18 @@ class GMLVQ(SiamesePrototypeModel, GLVQ):
|
|||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
plt.show(block=True)
|
plt.show(block=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def _forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
x, protos = get_flat(x, protos)
|
x, protos = get_flat(x, protos)
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
|
self.backbone.requires_grad_(self.both_path_gradients)
|
||||||
latent_protos = self.backbone(protos)
|
latent_protos = self.backbone(protos)
|
||||||
dis = self.distance_fn(latent_x, latent_protos)
|
self.backbone.requires_grad_(True)
|
||||||
return dis
|
distances = self.distance_fn(latent_x, latent_protos)
|
||||||
|
return distances
|
||||||
|
|
||||||
|
|
||||||
class LVQMLN(SiamesePrototypeModel, GLVQ):
|
class LVQMLN(SiameseGLVQ):
|
||||||
"""Learning Vector Quantization Multi-Layer Network.
|
"""Learning Vector Quantization Multi-Layer Network.
|
||||||
|
|
||||||
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
|
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
|
||||||
@ -216,17 +227,11 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
|
|||||||
rather in the embedding space.
|
rather in the embedding space.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, hparams, backbone=torch.nn.Identity(), **kwargs):
|
def _forward(self, x):
|
||||||
super().__init__(hparams, **kwargs)
|
|
||||||
self.backbone = backbone
|
|
||||||
|
|
||||||
self.distance_fn = kwargs.get("distance_fn", sed)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
latent_protos, _ = self.proto_layer()
|
latent_protos, _ = self.proto_layer()
|
||||||
latent_x = self.backbone(x)
|
latent_x = self.backbone(x)
|
||||||
dis = self.distance_fn(latent_x, latent_protos)
|
distances = self.distance_fn(latent_x, latent_protos)
|
||||||
return dis
|
return distances
|
||||||
|
|
||||||
|
|
||||||
class NonGradientGLVQ(GLVQ):
|
class NonGradientGLVQ(GLVQ):
|
||||||
@ -244,7 +249,7 @@ class LVQ1(NonGradientGLVQ):
|
|||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self(x)
|
dis = self._forward(x)
|
||||||
# TODO Vectorized implementation
|
# TODO Vectorized implementation
|
||||||
|
|
||||||
for xi, yi in zip(x, y):
|
for xi, yi in zip(x, y):
|
||||||
@ -272,7 +277,7 @@ class LVQ21(NonGradientGLVQ):
|
|||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
|
|
||||||
x, y = train_batch
|
x, y = train_batch
|
||||||
dis = self(x)
|
dis = self._forward(x)
|
||||||
# TODO Vectorized implementation
|
# TODO Vectorized implementation
|
||||||
|
|
||||||
for xi, yi in zip(x, y):
|
for xi, yi in zip(x, y):
|
||||||
|
Loading…
Reference in New Issue
Block a user