feat: gtlvq with examples

This commit is contained in:
Christoph 2021-11-10 17:04:24 +00:00 committed by Jensun Ravichandran
parent 6ffd27d12a
commit d3bb430104
4 changed files with 194 additions and 0 deletions

130
examples/gtlvq_moons.py Normal file
View File

@ -0,0 +1,130 @@
"""Localized-GMLVQ example using the Moons dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256,
shuffle=True)
# Hyperparameters
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
# Initialize the model
model = pt.models.GTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
"""Localized-GMLVQ example using the Moons dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256,
shuffle=True)
# Hyperparameters
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
# Initialize the model
model = pt.models.GTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -8,6 +8,7 @@ from .glvq import (
GLVQ21, GLVQ21,
GMLVQ, GMLVQ,
GRLVQ, GRLVQ,
GTLVQ,
LGMLVQ, LGMLVQ,
LVQMLN, LVQMLN,
ImageGLVQ, ImageGLVQ,

View File

@ -15,6 +15,44 @@ def rank_scaled_gaussian(distances, lambd):
return torch.exp(-torch.exp(-ranks / lambd) * distances) return torch.exp(-torch.exp(-ranks / lambd) * distances)
def orthogonalization(tensors):
"""Orthogonalization via polar decomposition """
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def ltangent_distance(x, y, omegas):
r"""Localized Tangent distance.
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
:param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor`
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p
projected_y = torch.diagonal(y @ p).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
distances = distances.permute(1, 0)
return distances
class GaussianPrior(torch.nn.Module): class GaussianPrior(torch.nn.Module):
def __init__(self, variance): def __init__(self, variance):
super().__init__() super().__init__()

View File

@ -10,6 +10,7 @@ from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from ..core.transforms import LinearTransform from ..core.transforms import LinearTransform
from ..nn.wrappers import LambdaLayer, LossLayer from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization
class GLVQ(SupervisedPrototypeModel): class GLVQ(SupervisedPrototypeModel):
@ -282,6 +283,30 @@ class LGMLVQ(GMLVQ):
self.register_parameter("_omega", Parameter(omega)) self.register_parameter("_omega", Parameter(omega))
class GTLVQ(LGMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
omega_initializer = kwargs.get("omega_initializer")
omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
# Re-register `_omega` to override the one from the super class.
omega = torch.rand(
self.num_prototypes,
self.hparams.input_dim,
self.hparams.latent_dim,
device=self.device,
)
self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))
class GLVQ1(GLVQ): class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1.""" """Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):