feat: gtlvq with examples

This commit is contained in:
Christoph 2021-11-10 17:04:24 +00:00 committed by Jensun Ravichandran
parent 6ffd27d12a
commit d3bb430104
4 changed files with 194 additions and 0 deletions

130
examples/gtlvq_moons.py Normal file
View File

@ -0,0 +1,130 @@
"""Localized-GMLVQ example using the Moons dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256,
shuffle=True)
# Hyperparameters
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
# Initialize the model
model = pt.models.GTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)
"""Localized-GMLVQ example using the Moons dataset."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=2)
# Dataset
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
# Dataloaders
train_loader = torch.utils.data.DataLoader(train_ds,
batch_size=256,
shuffle=True)
# Hyperparameters
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
# Initialize the model
model = pt.models.GTLVQ(
hparams,
prototypes_initializer=pt.initializers.SMCI(train_ds),
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
train_ds))
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 2)
# Summary
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
es = pl.callbacks.EarlyStopping(
monitor="train_acc",
min_delta=0.001,
patience=20,
mode="max",
verbose=False,
check_on_train_epoch_end=True,
)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
callbacks=[
vis,
es,
],
weights_summary="full",
accelerator="ddp",
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -8,6 +8,7 @@ from .glvq import (
GLVQ21,
GMLVQ,
GRLVQ,
GTLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,

View File

@ -15,6 +15,44 @@ def rank_scaled_gaussian(distances, lambd):
return torch.exp(-torch.exp(-ranks / lambd) * distances)
def orthogonalization(tensors):
"""Orthogonalization via polar decomposition """
u, _, v = torch.svd(tensors, compute_uv=True)
u_shape = tuple(list(u.shape))
v_shape = tuple(list(v.shape))
# reshape to (num x N x M)
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
out = u @ v.permute([0, 2, 1])
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
return out
def ltangent_distance(x, y, omegas):
r"""Localized Tangent distance.
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
:param `torch.tensor` omegas: Three dimensional matrix
:rtype: `torch.tensor`
"""
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
omegas, omegas.permute([0, 2, 1]))
projected_x = x @ p
projected_y = torch.diagonal(y @ p).T
expanded_y = torch.unsqueeze(projected_y, dim=1)
batchwise_difference = expanded_y - projected_x
differences_squared = batchwise_difference**2
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
distances = distances.permute(1, 0)
return distances
class GaussianPrior(torch.nn.Module):
def __init__(self, variance):
super().__init__()

View File

@ -10,6 +10,7 @@ from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
from ..core.transforms import LinearTransform
from ..nn.wrappers import LambdaLayer, LossLayer
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
from .extras import ltangent_distance, orthogonalization
class GLVQ(SupervisedPrototypeModel):
@ -282,6 +283,30 @@ class LGMLVQ(GMLVQ):
self.register_parameter("_omega", Parameter(omega))
class GTLVQ(LGMLVQ):
"""Localized and Generalized Matrix Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
omega_initializer = kwargs.get("omega_initializer")
omega = omega_initializer.generate(self.hparams.input_dim,
self.hparams.latent_dim)
# Re-register `_omega` to override the one from the super class.
omega = torch.rand(
self.num_prototypes,
self.hparams.input_dim,
self.hparams.latent_dim,
device=self.device,
)
self.register_parameter("_omega", Parameter(omega))
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
with torch.no_grad():
self._omega.copy_(orthogonalization(self._omega))
class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs):