Added Vis for GMLVQ with more then 2 dims using PCA (#11)
* Added Vis for GMLVQ with more then 2 dims using PCA * Added initialization possibility to GMlVQ with PCA and one example with omega init + PCA vis of 3 dims * test(githooks): Add githooks for automatic commit checks Co-authored-by: staps@hs-mittweida.de <staps@hs-mittweida.de> Co-authored-by: Alexander Engelsberger <alexanderengelsberger@gmail.com>
This commit is contained in:
parent
8956ee75ad
commit
0a2da9ae50
59
examples/gmlvq_iris.py
Normal file
59
examples/gmlvq_iris.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
"""GLVQ example using the Iris dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Iris()
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(
|
||||||
|
input_dim=4,
|
||||||
|
latent_dim=3,
|
||||||
|
distribution={
|
||||||
|
"num_classes": 3,
|
||||||
|
"prototypes_per_class": 2
|
||||||
|
},
|
||||||
|
proto_lr=0.0005,
|
||||||
|
bb_lr=0.0005,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.GMLVQ(
|
||||||
|
hparams,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
prototype_initializer=pt.components.SSI(train_ds),
|
||||||
|
lr_scheduler=ExponentialLR,
|
||||||
|
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||||
|
omega_initializer=pt.components.PCA(train_ds.data)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
#model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisGMLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[vis],
|
||||||
|
weights_summary="full",
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
BIN
prototorch/models/.glvq.py.swp
Normal file
BIN
prototorch/models/.glvq.py.swp
Normal file
Binary file not shown.
@ -7,6 +7,7 @@ from prototorch.functions.distances import (lomega_distance, omega_distance,
|
|||||||
squared_euclidean_distance)
|
squared_euclidean_distance)
|
||||||
from prototorch.functions.helper import get_flat
|
from prototorch.functions.helper import get_flat
|
||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||||
|
from prototorch.components import LinearMapping
|
||||||
from prototorch.modules import LambdaLayer, LossLayer
|
from prototorch.modules import LambdaLayer, LossLayer
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
@ -239,10 +240,17 @@ class GMLVQ(GLVQ):
|
|||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
# Additional parameters
|
# Additional parameters
|
||||||
omega = torch.randn(self.hparams.input_dim,
|
omega_initializer = kwargs.get("omega_initializer", None)
|
||||||
self.hparams.latent_dim,
|
initialized_omega = kwargs.get("initialized_omega", None)
|
||||||
device=self.device)
|
if omega_initializer is not None or initialized_omega is not None:
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.omega_layer = LinearMapping(
|
||||||
|
mapping_shape=(self.hparams.input_dim, self.hparams.latent_dim),
|
||||||
|
initializer=omega_initializer,
|
||||||
|
initialized_linearmapping=initialized_omega,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_parameter("_omega", Parameter(self.omega_layer.mapping))
|
||||||
|
self.backbone = LambdaLayer(lambda x: x @ self._omega, name = "omega matrix")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def omega_matrix(self):
|
def omega_matrix(self):
|
||||||
@ -256,6 +264,24 @@ class GMLVQ(GLVQ):
|
|||||||
def extra_repr(self):
|
def extra_repr(self):
|
||||||
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
return f"(omega): (shape: {tuple(self._omega.shape)})"
|
||||||
|
|
||||||
|
def predict_latent(self, x, map_protos=True):
|
||||||
|
"""Predict `x` assuming it is already embedded in the latent space.
|
||||||
|
|
||||||
|
Only the prototypes are embedded in the latent space using the
|
||||||
|
backbone.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
protos, plabels = self.proto_layer()
|
||||||
|
if map_protos:
|
||||||
|
protos = self.backbone(protos)
|
||||||
|
d = squared_euclidean_distance(x, protos)
|
||||||
|
y_pred = wtac(d, plabels)
|
||||||
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LGMLVQ(GMLVQ):
|
class LGMLVQ(GMLVQ):
|
||||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||||
|
@ -83,7 +83,13 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
||||||
return mesh_input, xx, yy
|
return mesh_input, xx, yy
|
||||||
|
|
||||||
def plot_data(self, ax, x, y):
|
def perform_pca_2D(self, data):
|
||||||
|
(_, eigVal, eigVec) = torch.pca_lowrank(data, q=2)
|
||||||
|
return data @ eigVec
|
||||||
|
|
||||||
|
def plot_data(self, ax, x, y, pca=False):
|
||||||
|
if pca:
|
||||||
|
x = self.perform_pca_2D(x)
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
x[:, 0],
|
x[:, 0],
|
||||||
x[:, 1],
|
x[:, 1],
|
||||||
@ -94,7 +100,9 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
s=30,
|
s=30,
|
||||||
)
|
)
|
||||||
|
|
||||||
def plot_protos(self, ax, protos, plabels):
|
def plot_protos(self, ax, protos, plabels, pca=False):
|
||||||
|
if pca:
|
||||||
|
protos = self.perform_pca_2D(protos)
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
protos[:, 0],
|
protos[:, 0],
|
||||||
protos[:, 1],
|
protos[:, 1],
|
||||||
@ -186,6 +194,50 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
self.log_and_display(trainer, pl_module)
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
|
||||||
|
class VisGMLVQ2D(Vis2DAbstract):
|
||||||
|
def __init__(self, *args, map_protos=True, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.map_protos = map_protos
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
|
||||||
|
protos = pl_module.prototypes
|
||||||
|
plabels = pl_module.prototype_labels
|
||||||
|
x_train, y_train = self.x_train, self.y_train
|
||||||
|
device = pl_module.device
|
||||||
|
with torch.no_grad():
|
||||||
|
x_train = pl_module.backbone(torch.Tensor(x_train).to(device))
|
||||||
|
x_train = x_train.cpu().detach()
|
||||||
|
if self.map_protos:
|
||||||
|
with torch.no_grad():
|
||||||
|
protos = pl_module.backbone(torch.Tensor(protos).to(device))
|
||||||
|
protos = protos.cpu().detach()
|
||||||
|
ax = self.setup_ax()
|
||||||
|
if x_train.shape[1] > 2:
|
||||||
|
self.plot_data(ax, x_train, y_train, pca=True)
|
||||||
|
else:
|
||||||
|
self.plot_data(ax, x_train, y_train, pca=False)
|
||||||
|
if self.show_protos:
|
||||||
|
if protos.shape[1] > 2:
|
||||||
|
self.plot_protos(ax, protos, plabels, pca=True)
|
||||||
|
else:
|
||||||
|
self.plot_protos(ax, protos, plabels, pca=False)
|
||||||
|
### something to work on: meshgrid with pca
|
||||||
|
# x = np.vstack((x_train, protos))
|
||||||
|
# mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
|
#else:
|
||||||
|
# mesh_input, xx, yy = self.get_mesh_input(x_train)
|
||||||
|
#_components = pl_module.proto_layer._components
|
||||||
|
#mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
||||||
|
#y_pred = pl_module.predict_latent(mesh_input,
|
||||||
|
# map_protos=self.map_protos)
|
||||||
|
#y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
|
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
|
||||||
class VisCBC2D(Vis2DAbstract):
|
class VisCBC2D(Vis2DAbstract):
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
if not self.precheck(trainer):
|
if not self.precheck(trainer):
|
||||||
|
Loading…
Reference in New Issue
Block a user