feat(vis): 2D EV projection for GMLVQ

This commit is contained in:
Jensun Ravichandran 2021-09-01 10:49:57 +02:00
parent 7d4a041df2
commit fa928afe2c
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F
3 changed files with 97 additions and 0 deletions

58
examples/gmlvq_iris.py Normal file
View File

@ -0,0 +1,58 @@
"""GMLVQ example using the Iris dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Dataset
train_ds = pt.datasets.Iris()
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
# Hyperparameters
hparams = dict(
input_dim=4,
latent_dim=4,
distribution={
"num_classes": 3,
"per_class": 2
},
proto_lr=0.01,
bb_lr=0.01,
)
# Initialize the model
model = pt.models.GMLVQ(
hparams,
optimizer=torch.optim.Adam,
prototypes_initializer=pt.initializers.SMCI(train_ds),
lr_scheduler=ExponentialLR,
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 4)
# Callbacks
vis = pt.models.VisGMLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[vis],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -251,6 +251,12 @@ class GMLVQ(GLVQ):
def omega_matrix(self): def omega_matrix(self):
return self._omega.detach().cpu() return self._omega.detach().cpu()
@property
def lambda_matrix(self):
omega = self._omega.detach() # (input_dim, latent_dim)
lam = omega @ omega.T
return lam.detach().cpu()
def compute_distances(self, x): def compute_distances(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
distances = self.distance_layer(x, protos, self._omega) distances = self.distance_layer(x, protos, self._omega)

View File

@ -178,6 +178,39 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
class VisGMLVQ2D(Vis2DAbstract):
def __init__(self, *args, ev_proj=True, **kwargs):
super().__init__(*args, **kwargs)
self.ev_proj = ev_proj
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
device = pl_module.device
omega = pl_module._omega.detach()
lam = omega @ omega.T
u, _, _ = torch.pca_lowrank(lam, q=2)
with torch.no_grad():
x_train = torch.Tensor(x_train).to(device)
x_train = x_train @ u
x_train = x_train.cpu().detach()
if self.show_protos:
with torch.no_grad():
protos = torch.Tensor(protos).to(device)
protos = protos @ u
protos = protos.cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract): class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer): if not self.precheck(trainer):