feat: gtlvq with examples
This commit is contained in:
parent
6ffd27d12a
commit
d3bb430104
130
examples/gtlvq_moons.py
Normal file
130
examples/gtlvq_moons.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
"""Localized-GMLVQ example using the Moons dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
pl.utilities.seed.seed_everything(seed=2)
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
batch_size=256,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.GTLVQ(
|
||||||
|
hparams,
|
||||||
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
|
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
|
||||||
|
train_ds))
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_acc",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=20,
|
||||||
|
mode="max",
|
||||||
|
verbose=False,
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
es,
|
||||||
|
],
|
||||||
|
weights_summary="full",
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
||||||
|
"""Localized-GMLVQ example using the Moons dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
pl.utilities.seed.seed_everything(seed=2)
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
batch_size=256,
|
||||||
|
shuffle=True)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.GTLVQ(
|
||||||
|
hparams,
|
||||||
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
|
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
|
||||||
|
train_ds))
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_acc",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=20,
|
||||||
|
mode="max",
|
||||||
|
verbose=False,
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
es,
|
||||||
|
],
|
||||||
|
weights_summary="full",
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
@ -8,6 +8,7 @@ from .glvq import (
|
|||||||
GLVQ21,
|
GLVQ21,
|
||||||
GMLVQ,
|
GMLVQ,
|
||||||
GRLVQ,
|
GRLVQ,
|
||||||
|
GTLVQ,
|
||||||
LGMLVQ,
|
LGMLVQ,
|
||||||
LVQMLN,
|
LVQMLN,
|
||||||
ImageGLVQ,
|
ImageGLVQ,
|
||||||
|
@ -15,6 +15,44 @@ def rank_scaled_gaussian(distances, lambd):
|
|||||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||||
|
|
||||||
|
|
||||||
|
def orthogonalization(tensors):
|
||||||
|
"""Orthogonalization via polar decomposition """
|
||||||
|
u, _, v = torch.svd(tensors, compute_uv=True)
|
||||||
|
u_shape = tuple(list(u.shape))
|
||||||
|
v_shape = tuple(list(v.shape))
|
||||||
|
|
||||||
|
# reshape to (num x N x M)
|
||||||
|
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
|
||||||
|
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
|
||||||
|
|
||||||
|
out = u @ v.permute([0, 2, 1])
|
||||||
|
|
||||||
|
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def ltangent_distance(x, y, omegas):
|
||||||
|
r"""Localized Tangent distance.
|
||||||
|
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
|
||||||
|
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
|
||||||
|
|
||||||
|
:param `torch.tensor` omegas: Three dimensional matrix
|
||||||
|
:rtype: `torch.tensor`
|
||||||
|
"""
|
||||||
|
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||||
|
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
||||||
|
omegas, omegas.permute([0, 2, 1]))
|
||||||
|
projected_x = x @ p
|
||||||
|
projected_y = torch.diagonal(y @ p).T
|
||||||
|
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||||
|
batchwise_difference = expanded_y - projected_x
|
||||||
|
differences_squared = batchwise_difference**2
|
||||||
|
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
|
||||||
|
distances = distances.permute(1, 0)
|
||||||
|
return distances
|
||||||
|
|
||||||
|
|
||||||
class GaussianPrior(torch.nn.Module):
|
class GaussianPrior(torch.nn.Module):
|
||||||
def __init__(self, variance):
|
def __init__(self, variance):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -10,6 +10,7 @@ from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
|||||||
from ..core.transforms import LinearTransform
|
from ..core.transforms import LinearTransform
|
||||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
|
from .extras import ltangent_distance, orthogonalization
|
||||||
|
|
||||||
|
|
||||||
class GLVQ(SupervisedPrototypeModel):
|
class GLVQ(SupervisedPrototypeModel):
|
||||||
@ -282,6 +283,30 @@ class LGMLVQ(GMLVQ):
|
|||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
|
|
||||||
|
class GTLVQ(LGMLVQ):
|
||||||
|
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||||
|
def __init__(self, hparams, **kwargs):
|
||||||
|
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
|
||||||
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
|
omega_initializer = kwargs.get("omega_initializer")
|
||||||
|
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||||
|
self.hparams.latent_dim)
|
||||||
|
|
||||||
|
# Re-register `_omega` to override the one from the super class.
|
||||||
|
omega = torch.rand(
|
||||||
|
self.num_prototypes,
|
||||||
|
self.hparams.input_dim,
|
||||||
|
self.hparams.latent_dim,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
|
with torch.no_grad():
|
||||||
|
self._omega.copy_(orthogonalization(self._omega))
|
||||||
|
|
||||||
|
|
||||||
class GLVQ1(GLVQ):
|
class GLVQ1(GLVQ):
|
||||||
"""Generalized Learning Vector Quantization 1."""
|
"""Generalized Learning Vector Quantization 1."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user