[WIP] Update CBC implementation to use SiameseGLVQ

This commit is contained in:
Jensun Ravichandran 2021-05-20 17:36:00 +02:00
parent 49f9a12b5f
commit 88a34a06ef
3 changed files with 83 additions and 89 deletions

View File

@ -6,13 +6,10 @@ import torch
if __name__ == "__main__":
# Dataset
from sklearn.datasets import load_iris
x_train, y_train = load_iris(return_X_y=True)
x_train = x_train[:, [0, 2]]
train_ds = pt.datasets.NumpyDataset(x_train, y_train)
train_ds = pt.datasets.Iris(dims=[0, 2])
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
pl.utilities.seed.seed_everything(seed=3)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
@ -21,18 +18,19 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
input_dim=x_train.shape[1],
nclasses=3,
num_components=5,
component_initializer=pt.components.SSI(train_ds, noise=0.01),
lr=0.01,
distribution=[3, 2, 2],
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.CBC(hparams)
model = pt.models.CBC(
hparams,
prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
)
# Callbacks
dvis = pt.models.VisCBC2D(data=(x_train, y_train),
dvis = pt.models.VisCBC2D(data=train_ds,
title="CBC Iris Example",
resolution=300,
axis_off=True)

View File

@ -5,6 +5,10 @@ from prototorch.components.components import Components
from prototorch.functions.distances import euclidean_distance
from prototorch.functions.similarities import cosine_similarity
from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
SiamesePrototypeModel)
from .glvq import SiameseGLVQ
def rescaled_cosine_similarity(x, y):
"""Cosine Similarity rescaled to [0, 1]."""
@ -16,9 +20,9 @@ def shift_activation(x):
return (x + 1.0) / 2.0
def euclidean_similarity(x, y):
def euclidean_similarity(x, y, beta=3):
d = euclidean_distance(x, y)
return torch.exp(-d * 3)
return torch.exp(-d * beta)
class CosineSimilarity(torch.nn.Module):
@ -55,11 +59,12 @@ class MarginLoss(torch.nn.modules.loss._Loss):
class ReasoningLayer(torch.nn.Module):
def __init__(self, n_components, n_classes, n_replicas=1):
def __init__(self, num_components, num_classes, n_replicas=1):
super().__init__()
self.n_replicas = n_replicas
self.n_classes = n_classes
probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
self.num_classes = num_classes
probabilities_init = torch.zeros(2, 1, num_components,
self.num_classes)
probabilities_init.uniform_(0.4, 0.6)
self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@ -81,73 +86,59 @@ class ReasoningLayer(torch.nn.Module):
return probs
class CBC(pl.LightningModule):
class CBC(SiameseGLVQ):
"""Classification-By-Components."""
def __init__(self,
hparams,
margin=0.1,
backbone_class=torch.nn.Identity,
similarity=euclidean_similarity,
**kwargs):
super().__init__()
self.save_hyperparameters(hparams)
super().__init__(hparams, **kwargs)
self.margin = margin
self.component_layer = Components(self.hparams.num_components,
self.hparams.component_initializer)
# self.similarity = CosineSimilarity()
self.similarity = similarity
self.backbone = backbone_class()
self.backbone_dependent = backbone_class().requires_grad_(False)
n_components = self.components.shape[0]
self.reasoning_layer = ReasoningLayer(n_components=n_components,
n_classes=self.hparams.nclasses)
self.train_acc = torchmetrics.Accuracy()
self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
num_components = self.components.shape[0]
self.reasoning_layer = ReasoningLayer(num_components=num_components,
num_classes=self.num_classes)
self.component_layer = self.proto_layer
@property
def components(self):
return self.component_layer.components.detach().cpu()
return self.prototypes
@property
def reasonings(self):
return self.reasoning_layer.reasonings.cpu()
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
return optimizer
def sync_backbones(self):
master_state = self.backbone.state_dict()
self.backbone_dependent.load_state_dict(master_state, strict=True)
def forward(self, x):
self.sync_backbones()
protos = self.component_layer()
components, _ = self.component_layer()
latent_x = self.backbone(x)
latent_protos = self.backbone_dependent(protos)
detections = self.similarity(latent_x, latent_protos)
self.backbone.requires_grad_(self.both_path_gradients)
latent_components = self.backbone(components)
self.backbone.requires_grad_(True)
detections = self.similarity_fn(latent_x, latent_components)
probs = self.reasoning_layer(detections)
return probs
def training_step(self, train_batch, batch_idx):
x, y = train_batch
x = x.view(x.size(0), -1)
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
# x = x.view(x.size(0), -1)
y_pred = self(x)
nclasses = self.reasoning_layer.n_classes
nclasses = self.reasoning_layer.num_classes
y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
self.log("train_loss", loss)
self.train_acc(y_pred, y_true)
self.log(
"acc",
self.train_acc,
return y_pred, loss
def training_step(self, batch, batch_idx, optimizer_idx=None):
y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
preds = torch.argmax(y_pred, dim=1)
self.acc_metric(preds.int(), batch[1].int())
self.log("train_acc",
self.acc_metric,
on_step=False,
on_epoch=True,
prog_bar=True,
logger=True,
)
return loss
logger=True)
return train_loss
def predict(self, x):
with torch.no_grad():

View File

@ -49,11 +49,21 @@ class GLVQ(AbstractPrototypeModel):
def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu()
def forward(self, x):
@property
def num_classes(self):
return len(self.proto_layer.distribution)
def _forward(self, x):
protos, _ = self.proto_layer()
distances = self.distance_fn(x, protos)
return distances
def forward(self, x):
distances = self._forward(x)
y_pred = self.predict_from_distances(distances)
y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
return y_pred
def predict_from_distances(self, distances):
with torch.no_grad():
plabels = self.proto_layer.component_labels
@ -62,7 +72,7 @@ class GLVQ(AbstractPrototypeModel):
def predict(self, x):
with torch.no_grad():
distances = self(x)
distances = self._forward(x)
y_pred = self.predict_from_distances(distances)
return y_pred
@ -80,7 +90,7 @@ class GLVQ(AbstractPrototypeModel):
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self(x)
out = self._forward(x)
plabels = self.proto_layer.component_labels
mu = self.loss(out, y, prototype_labels=plabels)
batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
@ -89,6 +99,7 @@ class GLVQ(AbstractPrototypeModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
self.log("train_loss", train_loss)
self.log_acc(out, batch[-1], tag="train_acc")
return train_loss
@ -137,23 +148,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
self.both_path_gradients = both_path_gradients
self.distance_fn = kwargs.get("distance_fn", sed)
def forward(self, x):
def _forward(self, x):
protos, _ = self.proto_layer()
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos)
self.backbone.requires_grad_(True)
dis = self.distance_fn(latent_x, latent_protos)
return dis
distances = self.distance_fn(latent_x, latent_protos)
return distances
class GRLVQ(SiamesePrototypeModel, GLVQ):
class GRLVQ(SiameseGLVQ):
"""Generalized Relevance Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.relevances = torch.nn.parameter.Parameter(
torch.ones(self.hparams.input_dim))
self.distance_fn = kwargs.get("distance_fn", sed)
@property
def relevance_profile(self):
@ -163,20 +173,19 @@ class GRLVQ(SiamesePrototypeModel, GLVQ):
"""Namespace hook for the visualization callbacks to work."""
return x @ torch.diag(self.relevances)
def forward(self, x):
def _forward(self, x):
protos, _ = self.proto_layer()
dis = omega_distance(x, protos, torch.diag(self.relevances))
return dis
distances = omega_distance(x, protos, torch.diag(self.relevances))
return distances
class GMLVQ(SiamesePrototypeModel, GLVQ):
class GMLVQ(SiameseGLVQ):
"""Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.backbone = torch.nn.Linear(self.hparams.input_dim,
self.hparams.latent_dim,
bias=False)
self.distance_fn = kwargs.get("distance_fn", sed)
@property
def omega_matrix(self):
@ -198,16 +207,18 @@ class GMLVQ(SiamesePrototypeModel, GLVQ):
plt.colorbar()
plt.show(block=True)
def forward(self, x):
def _forward(self, x):
protos, _ = self.proto_layer()
x, protos = get_flat(x, protos)
latent_x = self.backbone(x)
self.backbone.requires_grad_(self.both_path_gradients)
latent_protos = self.backbone(protos)
dis = self.distance_fn(latent_x, latent_protos)
return dis
self.backbone.requires_grad_(True)
distances = self.distance_fn(latent_x, latent_protos)
return distances
class LVQMLN(SiamesePrototypeModel, GLVQ):
class LVQMLN(SiameseGLVQ):
"""Learning Vector Quantization Multi-Layer Network.
GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
@ -216,17 +227,11 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
rather in the embedding space.
"""
def __init__(self, hparams, backbone=torch.nn.Identity(), **kwargs):
super().__init__(hparams, **kwargs)
self.backbone = backbone
self.distance_fn = kwargs.get("distance_fn", sed)
def forward(self, x):
def _forward(self, x):
latent_protos, _ = self.proto_layer()
latent_x = self.backbone(x)
dis = self.distance_fn(latent_x, latent_protos)
return dis
distances = self.distance_fn(latent_x, latent_protos)
return distances
class NonGradientGLVQ(GLVQ):
@ -244,7 +249,7 @@ class LVQ1(NonGradientGLVQ):
plabels = self.proto_layer.component_labels
x, y = train_batch
dis = self(x)
dis = self._forward(x)
# TODO Vectorized implementation
for xi, yi in zip(x, y):
@ -272,7 +277,7 @@ class LVQ21(NonGradientGLVQ):
plabels = self.proto_layer.component_labels
x, y = train_batch
dis = self(x)
dis = self._forward(x)
# TODO Vectorized implementation
for xi, yi in zip(x, y):