[WIP] Update CBC implementation to use SiameseGLVQ
parent 49f9a12b5f
commit 88a34a06ef
@@ -6,13 +6,10 @@ import torch

 if __name__ == "__main__":
     # Dataset
-    from sklearn.datasets import load_iris
-    x_train, y_train = load_iris(return_X_y=True)
-    x_train = x_train[:, [0, 2]]
-    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
+    train_ds = pt.datasets.Iris(dims=[0, 2])

     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=2)
+    pl.utilities.seed.seed_everything(seed=3)

     # Dataloaders
     train_loader = torch.utils.data.DataLoader(train_ds,
@@ -21,18 +18,19 @@ if __name__ == "__main__":

     # Hyperparameters
     hparams = dict(
-        input_dim=x_train.shape[1],
-        nclasses=3,
         num_components=5,
-        component_initializer=pt.components.SSI(train_ds, noise=0.01),
-        lr=0.01,
+        distribution=[3, 2, 2],
+        proto_lr=0.01,
+        bb_lr=0.01,
     )

     # Initialize the model
-    model = pt.models.CBC(hparams)
+    model = pt.models.CBC(
+        hparams,
+        prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
+    )

     # Callbacks
-    dvis = pt.models.VisCBC2D(data=(x_train, y_train),
+    dvis = pt.models.VisCBC2D(data=train_ds,
                               title="CBC Iris Example",
                               resolution=300,
                               axis_off=True)
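A minimal sketch (not part of the diff) of how the updated example might be driven end to end. The `pl.Trainer` call, `batch_size`, and `max_epochs` are assumptions; the dataset, seed, hparams, and model construction mirror the hunks above.

import prototorch as pt
import pytorch_lightning as pl
import torch

if __name__ == "__main__":
    # Dataset and reproducibility, as in the diff
    train_ds = pt.datasets.Iris(dims=[0, 2])
    pl.utilities.seed.seed_everything(seed=3)
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)

    # Per-class component distribution plus separate learning rates
    # for the components (proto_lr) and the backbone (bb_lr)
    hparams = dict(
        num_components=5,
        distribution=[3, 2, 2],
        proto_lr=0.01,
        bb_lr=0.01,
    )
    model = pt.models.CBC(
        hparams,
        prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
    )

    # Assumed: a vanilla Lightning training loop
    trainer = pl.Trainer(max_epochs=50)
    trainer.fit(model, train_loader)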
@@ -5,6 +5,10 @@ from prototorch.components.components import Components
 from prototorch.functions.distances import euclidean_distance
 from prototorch.functions.similarities import cosine_similarity

+from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
+                       SiamesePrototypeModel)
+from .glvq import SiameseGLVQ
+

 def rescaled_cosine_similarity(x, y):
     """Cosine Similarity rescaled to [0, 1]."""
@@ -16,9 +20,9 @@ def shift_activation(x):
     return (x + 1.0) / 2.0


-def euclidean_similarity(x, y):
+def euclidean_similarity(x, y, beta=3):
     d = euclidean_distance(x, y)
-    return torch.exp(-d * 3)
+    return torch.exp(-d * beta)


 class CosineSimilarity(torch.nn.Module):
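The hard-coded scaling factor is now an explicit `beta` keyword; the default of 3 preserves the old behaviour. A quick sketch of its effect, with `torch.cdist` standing in for prototorch's `euclidean_distance`:

import torch

def euclidean_similarity(x, y, beta=3):
    d = torch.cdist(x, y)  # stand-in for euclidean_distance
    return torch.exp(-d * beta)

x = torch.zeros(1, 2)
y = torch.ones(1, 2)  # distance sqrt(2) ~ 1.41 from x
print(euclidean_similarity(x, y))          # ~0.014: sharp decay (beta=3)
print(euclidean_similarity(x, y, beta=1))  # ~0.24: gentler decay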
@@ -55,11 +59,12 @@ class MarginLoss(torch.nn.modules.loss._Loss):


 class ReasoningLayer(torch.nn.Module):
-    def __init__(self, n_components, n_classes, n_replicas=1):
+    def __init__(self, num_components, num_classes, n_replicas=1):
         super().__init__()
         self.n_replicas = n_replicas
-        self.n_classes = n_classes
-        probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
+        self.num_classes = num_classes
+        probabilities_init = torch.zeros(2, 1, num_components,
+                                         self.num_classes)
         probabilities_init.uniform_(0.4, 0.6)
         self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
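Only the argument names change here (`n_*` to `num_*`); the initialization itself is untouched. For reference, a standalone sketch of the tensor it creates, with assumed sizes:

import torch

num_components, num_classes = 7, 3  # e.g. distribution=[3, 2, 2]
probabilities_init = torch.zeros(2, 1, num_components, num_classes)
probabilities_init.uniform_(0.4, 0.6)  # near-uniform start, as above
reasoning_probabilities = torch.nn.Parameter(probabilities_init)
print(reasoning_probabilities.shape)  # torch.Size([2, 1, 7, 3])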
@@ -81,73 +86,59 @@ class ReasoningLayer(torch.nn.Module):
         return probs


-class CBC(pl.LightningModule):
+class CBC(SiameseGLVQ):
     """Classification-By-Components."""
     def __init__(self,
                  hparams,
                  margin=0.1,
-                 backbone_class=torch.nn.Identity,
-                 similarity=euclidean_similarity,
                  **kwargs):
-        super().__init__()
-        self.save_hyperparameters(hparams)
+        super().__init__(hparams, **kwargs)
         self.margin = margin
-        self.component_layer = Components(self.hparams.num_components,
-                                          self.hparams.component_initializer)
-        # self.similarity = CosineSimilarity()
-        self.similarity = similarity
-        self.backbone = backbone_class()
-        self.backbone_dependent = backbone_class().requires_grad_(False)
-        n_components = self.components.shape[0]
-        self.reasoning_layer = ReasoningLayer(n_components=n_components,
-                                              n_classes=self.hparams.nclasses)
-        self.train_acc = torchmetrics.Accuracy()
+        self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
+        num_components = self.components.shape[0]
+        self.reasoning_layer = ReasoningLayer(num_components=num_components,
+                                              num_classes=self.num_classes)
+        self.component_layer = self.proto_layer

     @property
     def components(self):
-        return self.component_layer.components.detach().cpu()
+        return self.prototypes

     @property
     def reasonings(self):
         return self.reasoning_layer.reasonings.cpu()

-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
-        return optimizer
-
-    def sync_backbones(self):
-        master_state = self.backbone.state_dict()
-        self.backbone_dependent.load_state_dict(master_state, strict=True)
-
     def forward(self, x):
-        self.sync_backbones()
-        protos = self.component_layer()
-
+        components, _ = self.component_layer()
         latent_x = self.backbone(x)
-        latent_protos = self.backbone_dependent(protos)
-
-        detections = self.similarity(latent_x, latent_protos)
+        self.backbone.requires_grad_(self.both_path_gradients)
+        latent_components = self.backbone(components)
+        self.backbone.requires_grad_(True)
+        detections = self.similarity_fn(latent_x, latent_components)
         probs = self.reasoning_layer(detections)
         return probs

-    def training_step(self, train_batch, batch_idx):
-        x, y = train_batch
-        x = x.view(x.size(0), -1)
+    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+        x, y = batch
+        # x = x.view(x.size(0), -1)
         y_pred = self(x)
-        nclasses = self.reasoning_layer.n_classes
+        nclasses = self.reasoning_layer.num_classes
         y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
         loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
-        self.log("train_loss", loss)
-        self.train_acc(y_pred, y_true)
-        self.log(
-            "acc",
-            self.train_acc,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-        return loss
+        return y_pred, loss
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+        preds = torch.argmax(y_pred, dim=1)
+        self.acc_metric(preds.int(), batch[1].int())
+        self.log("train_acc",
+                 self.acc_metric,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+        return train_loss

     def predict(self, x):
         with torch.no_grad():
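With CBC inheriting from SiameseGLVQ, the manual `sync_backbones` dance disappears: a single backbone embeds both inputs and components, and its gradients are gated instead of copied. A shape-level sketch of the new forward pass, with assumed sizes and `torch.cdist` standing in for the similarity's distance:

import torch

batch, dim, num_components = 8, 2, 7
backbone = torch.nn.Identity()         # SiameseGLVQ's default backbone
x = torch.randn(batch, dim)
components = torch.randn(num_components, dim)

latent_x = backbone(x)
backbone.requires_grad_(False)         # both_path_gradients=False
latent_components = backbone(components)
backbone.requires_grad_(True)

d = torch.cdist(latent_x, latent_components)
detections = torch.exp(-d * 3)         # euclidean_similarity, beta=3
print(detections.shape)                # (8, 7), input to ReasoningLayer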
@@ -49,11 +49,21 @@ class GLVQ(AbstractPrototypeModel):
     def prototype_labels(self):
         return self.proto_layer.component_labels.detach().cpu()

-    def forward(self, x):
+    @property
+    def num_classes(self):
+        return len(self.proto_layer.distribution)
+
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_fn(x, protos)
         return distances

+    def forward(self, x):
+        distances = self._forward(x)
+        y_pred = self.predict_from_distances(distances)
+        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
+        return y_pred
+
     def predict_from_distances(self, distances):
         with torch.no_grad():
             plabels = self.proto_layer.component_labels
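The distance computation moves into `_forward`, and `forward` now returns one-hot predictions via `torch.eye`. A small sketch of the conversion with assumed values:

import torch

num_classes = 3
plabels = torch.tensor([0, 0, 1, 2])              # prototype labels
distances = torch.tensor([[0.9, 0.2, 1.5, 1.1]])  # one sample, 4 protos
winner = plabels[distances.argmin(dim=1)]         # nearest-prototype label
y_pred = torch.eye(num_classes)[winner]           # one-hot, as in forward()
print(y_pred)  # tensor([[1., 0., 0.]])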
@@ -62,7 +72,7 @@ class GLVQ(AbstractPrototypeModel):

     def predict(self, x):
         with torch.no_grad():
-            distances = self(x)
+            distances = self._forward(x)
             y_pred = self.predict_from_distances(distances)
         return y_pred

@@ -80,7 +90,7 @@ class GLVQ(AbstractPrototypeModel):

     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self(x)
+        out = self._forward(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(out, y, prototype_labels=plabels)
         batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)

@@ -89,6 +99,7 @@ class GLVQ(AbstractPrototypeModel):

     def training_step(self, batch, batch_idx, optimizer_idx=None):
         out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+        self.log("train_loss", train_loss)
         self.log_acc(out, batch[-1], tag="train_acc")
         return train_loss
@@ -137,23 +148,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
         self.both_path_gradients = both_path_gradients
         self.distance_fn = kwargs.get("distance_fn", sed)

-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances


-class GRLVQ(SiamesePrototypeModel, GLVQ):
+class GRLVQ(SiameseGLVQ):
     """Generalized Relevance Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.relevances = torch.nn.parameter.Parameter(
             torch.ones(self.hparams.input_dim))
-        self.distance_fn = kwargs.get("distance_fn", sed)

     @property
     def relevance_profile(self):
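The gradient gating in SiameseGLVQ's `_forward` is worth spelling out: prototypes pass through the same backbone as the inputs, but unless `both_path_gradients` is set, the prototype path contributes no backbone gradients. A sketch:

import torch

backbone = torch.nn.Linear(2, 2)
x = torch.randn(8, 2)
protos = torch.randn(5, 2, requires_grad=True)

latent_x = backbone(x)            # backbone weights get gradients here
backbone.requires_grad_(False)    # both_path_gradients defaults to False
latent_protos = backbone(protos)  # gradients reach protos, not backbone
backbone.requires_grad_(True)     # restore for the next step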
@@ -163,20 +173,19 @@ class GRLVQ(SiamesePrototypeModel, GLVQ):
         """Namespace hook for the visualization callbacks to work."""
         return x @ torch.diag(self.relevances)

-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
-        dis = omega_distance(x, protos, torch.diag(self.relevances))
-        return dis
+        distances = omega_distance(x, protos, torch.diag(self.relevances))
+        return distances


-class GMLVQ(SiamesePrototypeModel, GLVQ):
+class GMLVQ(SiameseGLVQ):
     """Generalized Matrix Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                         self.hparams.latent_dim,
                                         bias=False)
-        self.distance_fn = kwargs.get("distance_fn", sed)

     @property
     def omega_matrix(self):
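GRLVQ keeps its diagonal relevance matrix; only the naming changes (`forward` to `_forward`, `dis` to `distances`). A sketch of what `omega_distance` with a diagonal omega computes, assuming a squared euclidean distance in the transformed space:

import torch

relevances = torch.ones(4)        # trainable parameter in the model
omega = torch.diag(relevances)
x = torch.randn(8, 4)
protos = torch.randn(6, 4)
distances = torch.cdist(x @ omega, protos @ omega) ** 2
print(distances.shape)            # (8, 6)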
@@ -198,16 +207,18 @@ class GMLVQ(SiamesePrototypeModel, GLVQ):
         plt.colorbar()
         plt.show(block=True)

-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)
+        self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        self.backbone.requires_grad_(True)
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances


-class LVQMLN(SiamesePrototypeModel, GLVQ):
+class LVQMLN(SiameseGLVQ):
     """Learning Vector Quantization Multi-Layer Network.

     GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
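GMLVQ now applies the same gradient gating as SiameseGLVQ when embedding the prototypes. Its backbone is a bias-free linear layer, so the learned weight plays the role of the omega matrix. A sketch with assumed sizes, using `torch.cdist(...)**2` as a stand-in for `sed`:

import torch

input_dim, latent_dim = 4, 2
backbone = torch.nn.Linear(input_dim, latent_dim, bias=False)
x = torch.randn(8, input_dim)
protos = torch.randn(6, input_dim)
latent_x = backbone(x)
latent_protos = backbone(protos)
distances = torch.cdist(latent_x, latent_protos) ** 2
print(distances.shape)  # (8, 6)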
@@ -216,17 +227,11 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
     rather in the embedding space.

     """
-    def __init__(self, hparams, backbone=torch.nn.Identity(), **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.backbone = backbone
-
-        self.distance_fn = kwargs.get("distance_fn", sed)
-
-    def forward(self, x):
+    def _forward(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances


 class NonGradientGLVQ(GLVQ):
@@ -244,7 +249,7 @@ class LVQ1(NonGradientGLVQ):
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self(x)
+        dis = self._forward(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):

@@ -272,7 +277,7 @@ class LVQ21(NonGradientGLVQ):
         plabels = self.proto_layer.component_labels

         x, y = train_batch
-        dis = self(x)
+        dis = self._forward(x)
         # TODO Vectorized implementation

         for xi, yi in zip(x, y):