[WIP] Update CBC implementation to use SiameseGLVQ

Jensun Ravichandran 2021-05-20 17:36:00 +02:00
parent 49f9a12b5f
commit 88a34a06ef
3 changed files with 83 additions and 89 deletions

examples/cbc_iris.py

@@ -6,13 +6,10 @@ import torch
 
 if __name__ == "__main__":
     # Dataset
-    from sklearn.datasets import load_iris
-    x_train, y_train = load_iris(return_X_y=True)
-    x_train = x_train[:, [0, 2]]
-    train_ds = pt.datasets.NumpyDataset(x_train, y_train)
+    train_ds = pt.datasets.Iris(dims=[0, 2])
 
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=2)
+    pl.utilities.seed.seed_everything(seed=3)
 
     # Dataloaders
     train_loader = torch.utils.data.DataLoader(train_ds,
@@ -21,18 +18,19 @@ if __name__ == "__main__":
 
     # Hyperparameters
     hparams = dict(
-        input_dim=x_train.shape[1],
-        nclasses=3,
-        num_components=5,
-        component_initializer=pt.components.SSI(train_ds, noise=0.01),
-        lr=0.01,
+        distribution=[3, 2, 2],
+        proto_lr=0.01,
+        bb_lr=0.01,
     )
 
     # Initialize the model
-    model = pt.models.CBC(hparams)
+    model = pt.models.CBC(
+        hparams,
+        prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
+    )
 
     # Callbacks
-    dvis = pt.models.VisCBC2D(data=(x_train, y_train),
+    dvis = pt.models.VisCBC2D(data=train_ds,
                               title="CBC Iris Example",
                               resolution=300,
                               axis_off=True)

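For reference, a minimal end-to-end version of the updated example might look as
follows. The Trainer setup, batch size, and epoch count are assumptions based on
the usual PyTorch-Lightning workflow, not part of this commit.

import prototorch as pt
import pytorch_lightning as pl
import torch

if __name__ == "__main__":
    # Dataset: the built-in Iris wrapper, restricted to two of the four dims
    train_ds = pt.datasets.Iris(dims=[0, 2])

    # Reproducibility
    pl.utilities.seed.seed_everything(seed=3)

    # Dataloaders (batch size is a placeholder)
    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)

    # Hyperparameters as introduced by this commit: 3/2/2 components per class
    # and separate learning rates for prototypes and backbone
    hparams = dict(
        distribution=[3, 2, 2],
        proto_lr=0.01,
        bb_lr=0.01,
    )

    # Initialize the model with an SSI prototype initializer
    model = pt.models.CBC(
        hparams,
        prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
    )

    # Callbacks: the visualizer now consumes the dataset directly
    dvis = pt.models.VisCBC2D(data=train_ds,
                              title="CBC Iris Example",
                              resolution=300,
                              axis_off=True)

    # Train (epoch count is a placeholder)
    trainer = pl.Trainer(max_epochs=200, callbacks=[dvis])
    trainer.fit(model, train_loader)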
prototorch/models/cbc.py

@@ -5,6 +5,10 @@ from prototorch.components.components import Components
 from prototorch.functions.distances import euclidean_distance
 from prototorch.functions.similarities import cosine_similarity
 
+from .abstract import (AbstractPrototypeModel, PrototypeImageModel,
+                       SiamesePrototypeModel)
+from .glvq import SiameseGLVQ
+
 
 def rescaled_cosine_similarity(x, y):
     """Cosine Similarity rescaled to [0, 1]."""
@@ -16,9 +20,9 @@ def shift_activation(x):
     return (x + 1.0) / 2.0
 
 
-def euclidean_similarity(x, y):
+def euclidean_similarity(x, y, beta=3):
     d = euclidean_distance(x, y)
-    return torch.exp(-d * 3)
+    return torch.exp(-d * beta)
 
 
 class CosineSimilarity(torch.nn.Module):
@@ -55,11 +59,12 @@ class MarginLoss(torch.nn.modules.loss._Loss):
 
 
 class ReasoningLayer(torch.nn.Module):
-    def __init__(self, n_components, n_classes, n_replicas=1):
+    def __init__(self, num_components, num_classes, n_replicas=1):
         super().__init__()
         self.n_replicas = n_replicas
-        self.n_classes = n_classes
-        probabilities_init = torch.zeros(2, 1, n_components, self.n_classes)
+        self.num_classes = num_classes
+        probabilities_init = torch.zeros(2, 1, num_components,
+                                         self.num_classes)
         probabilities_init.uniform_(0.4, 0.6)
         self.reasoning_probabilities = torch.nn.Parameter(probabilities_init)
@@ -81,73 +86,59 @@ class ReasoningLayer(torch.nn.Module):
         return probs
 
 
-class CBC(pl.LightningModule):
+class CBC(SiameseGLVQ):
     """Classification-By-Components."""
     def __init__(self,
                  hparams,
                  margin=0.1,
-                 backbone_class=torch.nn.Identity,
                  similarity=euclidean_similarity,
                  **kwargs):
-        super().__init__()
-        self.save_hyperparameters(hparams)
+        super().__init__(hparams, **kwargs)
         self.margin = margin
-        self.component_layer = Components(self.hparams.num_components,
-                                          self.hparams.component_initializer)
-        # self.similarity = CosineSimilarity()
-        self.similarity = similarity
-        self.backbone = backbone_class()
-        self.backbone_dependent = backbone_class().requires_grad_(False)
-        n_components = self.components.shape[0]
-        self.reasoning_layer = ReasoningLayer(n_components=n_components,
-                                              n_classes=self.hparams.nclasses)
-        self.train_acc = torchmetrics.Accuracy()
+        self.similarity_fn = kwargs.get("similarity_fn", euclidean_similarity)
+        num_components = self.components.shape[0]
+        self.reasoning_layer = ReasoningLayer(num_components=num_components,
+                                              num_classes=self.num_classes)
+        self.component_layer = self.proto_layer
 
     @property
     def components(self):
-        return self.component_layer.components.detach().cpu()
+        return self.prototypes
 
     @property
     def reasonings(self):
         return self.reasoning_layer.reasonings.cpu()
 
-    def configure_optimizers(self):
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
-        return optimizer
-
-    def sync_backbones(self):
-        master_state = self.backbone.state_dict()
-        self.backbone_dependent.load_state_dict(master_state, strict=True)
-
     def forward(self, x):
-        self.sync_backbones()
-        protos = self.component_layer()
+        components, _ = self.component_layer()
         latent_x = self.backbone(x)
-        latent_protos = self.backbone_dependent(protos)
-        detections = self.similarity(latent_x, latent_protos)
+        self.backbone.requires_grad_(self.both_path_gradients)
+        latent_components = self.backbone(components)
+        self.backbone.requires_grad_(True)
+        detections = self.similarity_fn(latent_x, latent_components)
         probs = self.reasoning_layer(detections)
         return probs
 
-    def training_step(self, train_batch, batch_idx):
-        x, y = train_batch
-        x = x.view(x.size(0), -1)
+    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+        x, y = batch
+        # x = x.view(x.size(0), -1)
         y_pred = self(x)
-        nclasses = self.reasoning_layer.n_classes
+        nclasses = self.reasoning_layer.num_classes
         y_true = torch.nn.functional.one_hot(y.long(), num_classes=nclasses)
         loss = MarginLoss(self.margin)(y_pred, y_true).mean(dim=0)
-        self.log("train_loss", loss)
-        self.train_acc(y_pred, y_true)
-        self.log(
-            "acc",
-            self.train_acc,
-            on_step=False,
-            on_epoch=True,
-            prog_bar=True,
-            logger=True,
-        )
-        return loss
+        return y_pred, loss
+
+    def training_step(self, batch, batch_idx, optimizer_idx=None):
+        y_pred, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+        preds = torch.argmax(y_pred, dim=1)
+        self.acc_metric(preds.int(), batch[1].int())
+        self.log("train_acc",
+                 self.acc_metric,
+                 on_step=False,
+                 on_epoch=True,
+                 prog_bar=True,
+                 logger=True)
+        return train_loss
 
     def predict(self, x):
         with torch.no_grad():

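As a side note on the similarity change in this file: euclidean_similarity now
exposes the previously hard-coded decay factor as a beta keyword. A standalone
sketch of the detection step (the sample tensors and shapes are illustrative,
not from the commit):

import torch
from prototorch.functions.distances import euclidean_distance

def euclidean_similarity(x, y, beta=3):
    """Exponentially decaying similarity in (0, 1]; beta controls the decay."""
    d = euclidean_distance(x, y)
    return torch.exp(-d * beta)

# Illustrative shapes: 5 samples and 7 components in 2 dimensions.
x = torch.randn(5, 2)
components = torch.randn(7, 2)
detections = euclidean_similarity(x, components)  # shape (5, 7); this is what
                                                  # CBC.forward feeds to the
                                                  # ReasoningLayer
# A larger beta sharpens the detections: distant components score close to 0.
sharper = euclidean_similarity(x, components, beta=10)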
prototorch/models/glvq.py

@@ -49,11 +49,21 @@ class GLVQ(AbstractPrototypeModel):
     def prototype_labels(self):
         return self.proto_layer.component_labels.detach().cpu()
 
-    def forward(self, x):
+    @property
+    def num_classes(self):
+        return len(self.proto_layer.distribution)
+
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         distances = self.distance_fn(x, protos)
         return distances
 
+    def forward(self, x):
+        distances = self._forward(x)
+        y_pred = self.predict_from_distances(distances)
+        y_pred = torch.eye(self.num_classes, device=self.device)[y_pred.int()]
+        return y_pred
+
     def predict_from_distances(self, distances):
         with torch.no_grad():
             plabels = self.proto_layer.component_labels
@@ -62,7 +72,7 @@ class GLVQ(AbstractPrototypeModel):
 
     def predict(self, x):
         with torch.no_grad():
-            distances = self(x)
+            distances = self._forward(x)
             y_pred = self.predict_from_distances(distances)
             return y_pred
@@ -80,7 +90,7 @@ class GLVQ(AbstractPrototypeModel):
 
     def shared_step(self, batch, batch_idx, optimizer_idx=None):
         x, y = batch
-        out = self(x)
+        out = self._forward(x)
         plabels = self.proto_layer.component_labels
         mu = self.loss(out, y, prototype_labels=plabels)
         batch_loss = self.transfer_fn(mu, beta=self.hparams.transfer_beta)
@@ -89,6 +99,7 @@ class GLVQ(AbstractPrototypeModel):
 
     def training_step(self, batch, batch_idx, optimizer_idx=None):
         out, train_loss = self.shared_step(batch, batch_idx, optimizer_idx)
+        self.log("train_loss", train_loss)
         self.log_acc(out, batch[-1], tag="train_acc")
         return train_loss
@@ -137,23 +148,22 @@ class SiameseGLVQ(SiamesePrototypeModel, GLVQ):
         self.both_path_gradients = both_path_gradients
         self.distance_fn = kwargs.get("distance_fn", sed)
 
-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
         self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
         self.backbone.requires_grad_(True)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances
 
 
-class GRLVQ(SiamesePrototypeModel, GLVQ):
+class GRLVQ(SiameseGLVQ):
     """Generalized Relevance Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.relevances = torch.nn.parameter.Parameter(
             torch.ones(self.hparams.input_dim))
-        self.distance_fn = kwargs.get("distance_fn", sed)
 
     @property
     def relevance_profile(self):
@@ -163,20 +173,19 @@ class GRLVQ(SiamesePrototypeModel, GLVQ):
         """Namespace hook for the visualization callbacks to work."""
         return x @ torch.diag(self.relevances)
 
-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
-        dis = omega_distance(x, protos, torch.diag(self.relevances))
-        return dis
+        distances = omega_distance(x, protos, torch.diag(self.relevances))
+        return distances
 
 
-class GMLVQ(SiamesePrototypeModel, GLVQ):
+class GMLVQ(SiameseGLVQ):
     """Generalized Matrix Learning Vector Quantization."""
     def __init__(self, hparams, **kwargs):
         super().__init__(hparams, **kwargs)
         self.backbone = torch.nn.Linear(self.hparams.input_dim,
                                         self.hparams.latent_dim,
                                         bias=False)
-        self.distance_fn = kwargs.get("distance_fn", sed)
 
     @property
     def omega_matrix(self):
@@ -198,16 +207,18 @@ class GMLVQ(SiamesePrototypeModel, GLVQ):
         plt.colorbar()
         plt.show(block=True)
 
-    def forward(self, x):
+    def _forward(self, x):
         protos, _ = self.proto_layer()
         x, protos = get_flat(x, protos)
         latent_x = self.backbone(x)
+        self.backbone.requires_grad_(self.both_path_gradients)
         latent_protos = self.backbone(protos)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        self.backbone.requires_grad_(True)
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances
 
 
-class LVQMLN(SiamesePrototypeModel, GLVQ):
+class LVQMLN(SiameseGLVQ):
     """Learning Vector Quantization Multi-Layer Network.
 
     GLVQ model that applies an arbitrary transformation on the inputs, BUT NOT
@@ -216,17 +227,11 @@ class LVQMLN(SiamesePrototypeModel, GLVQ):
     rather in the embedding space.
     """
-    def __init__(self, hparams, backbone=torch.nn.Identity(), **kwargs):
-        super().__init__(hparams, **kwargs)
-        self.backbone = backbone
-        self.distance_fn = kwargs.get("distance_fn", sed)
-
-    def forward(self, x):
+    def _forward(self, x):
         latent_protos, _ = self.proto_layer()
         latent_x = self.backbone(x)
-        dis = self.distance_fn(latent_x, latent_protos)
-        return dis
+        distances = self.distance_fn(latent_x, latent_protos)
+        return distances
 
 
 class NonGradientGLVQ(GLVQ):
@@ -244,7 +249,7 @@ class LVQ1(NonGradientGLVQ):
         plabels = self.proto_layer.component_labels
         x, y = train_batch
-        dis = self(x)
+        dis = self._forward(x)
 
         # TODO Vectorized implementation
         for xi, yi in zip(x, y):
@@ -272,7 +277,7 @@ class LVQ21(NonGradientGLVQ):
         plabels = self.proto_layer.component_labels
         x, y = train_batch
-        dis = self(x)
+        dis = self._forward(x)
 
         # TODO Vectorized implementation
         for xi, yi in zip(x, y):
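The recurring forward -> _forward rename in this file separates the raw distance
computation (used by the loss and the non-gradient LVQ variants) from the public
forward, which now returns one-hot winner-takes-all predictions. A condensed,
self-contained sketch of that pattern with a stand-in model (not the actual GLVQ
class, which delegates to predict_from_distances and its prototype labels):

import torch

class DistanceClassifier(torch.nn.Module):
    """Stand-in with one prototype per class to illustrate the API split."""
    def __init__(self, num_classes, dim):
        super().__init__()
        self.protos = torch.nn.Parameter(torch.randn(num_classes, dim))
        self.num_classes = num_classes

    def _forward(self, x):
        # Raw pairwise distances; this is what training code consumes.
        return torch.cdist(x, self.protos)

    def forward(self, x):
        # Winner-takes-all prediction, returned one-hot like the new
        # GLVQ.forward, so callbacks can call the model directly.
        distances = self._forward(x)
        y_pred = distances.argmin(dim=1)
        return torch.eye(self.num_classes)[y_pred]

model = DistanceClassifier(num_classes=3, dim=2)
x = torch.randn(4, 2)
print(model._forward(x).shape)  # torch.Size([4, 3]): distances for the loss
print(model(x))                 # one-hot predictions for downstream consumers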