feat: gtlvq with examples
This commit is contained in:
parent
6ffd27d12a
commit
d3bb430104
130
examples/gtlvq_moons.py
Normal file
130
examples/gtlvq_moons.py
Normal file
@ -0,0 +1,130 @@
|
||||
"""Localized-GMLVQ example using the Moons dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
batch_size=256,
|
||||
shuffle=True)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GTLVQ(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
|
||||
train_ds))
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_acc",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="max",
|
||||
verbose=False,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
es,
|
||||
],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
"""Localized-GMLVQ example using the Moons dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
batch_size=256,
|
||||
shuffle=True)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GTLVQ(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
|
||||
train_ds))
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="train_acc",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="max",
|
||||
verbose=False,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
es,
|
||||
],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@ -8,6 +8,7 @@ from .glvq import (
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
GTLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
|
@ -15,6 +15,44 @@ def rank_scaled_gaussian(distances, lambd):
|
||||
return torch.exp(-torch.exp(-ranks / lambd) * distances)
|
||||
|
||||
|
||||
def orthogonalization(tensors):
|
||||
"""Orthogonalization via polar decomposition """
|
||||
u, _, v = torch.svd(tensors, compute_uv=True)
|
||||
u_shape = tuple(list(u.shape))
|
||||
v_shape = tuple(list(v.shape))
|
||||
|
||||
# reshape to (num x N x M)
|
||||
u = torch.reshape(u, (-1, u_shape[-2], u_shape[-1]))
|
||||
v = torch.reshape(v, (-1, v_shape[-2], v_shape[-1]))
|
||||
|
||||
out = u @ v.permute([0, 2, 1])
|
||||
|
||||
out = torch.reshape(out, u_shape[:-1] + (v_shape[-2], ))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def ltangent_distance(x, y, omegas):
|
||||
r"""Localized Tangent distance.
|
||||
Compute Orthogonal Complement: math:`\bm P_k = \bm I - \Omega_k \Omega_k^T`
|
||||
Compute Tangent Distance: math:`{\| \bm P \bm x - \bm P_k \bm y_k \|}_2`
|
||||
|
||||
:param `torch.tensor` omegas: Three dimensional matrix
|
||||
:rtype: `torch.tensor`
|
||||
"""
|
||||
x, y = [arr.view(arr.size(0), -1) for arr in (x, y)]
|
||||
p = torch.eye(omegas.shape[-2], device=omegas.device) - torch.bmm(
|
||||
omegas, omegas.permute([0, 2, 1]))
|
||||
projected_x = x @ p
|
||||
projected_y = torch.diagonal(y @ p).T
|
||||
expanded_y = torch.unsqueeze(projected_y, dim=1)
|
||||
batchwise_difference = expanded_y - projected_x
|
||||
differences_squared = batchwise_difference**2
|
||||
distances = torch.sqrt(torch.sum(differences_squared, dim=2))
|
||||
distances = distances.permute(1, 0)
|
||||
return distances
|
||||
|
||||
|
||||
class GaussianPrior(torch.nn.Module):
|
||||
def __init__(self, variance):
|
||||
super().__init__()
|
||||
|
@ -10,6 +10,7 @@ from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
||||
from ..core.transforms import LinearTransform
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
from .extras import ltangent_distance, orthogonalization
|
||||
|
||||
|
||||
class GLVQ(SupervisedPrototypeModel):
|
||||
@ -282,6 +283,30 @@ class LGMLVQ(GMLVQ):
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
|
||||
class GTLVQ(LGMLVQ):
|
||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
|
||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||
|
||||
omega_initializer = kwargs.get("omega_initializer")
|
||||
omega = omega_initializer.generate(self.hparams.input_dim,
|
||||
self.hparams.latent_dim)
|
||||
|
||||
# Re-register `_omega` to override the one from the super class.
|
||||
omega = torch.rand(
|
||||
self.num_prototypes,
|
||||
self.hparams.input_dim,
|
||||
self.hparams.latent_dim,
|
||||
device=self.device,
|
||||
)
|
||||
self.register_parameter("_omega", Parameter(omega))
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
with torch.no_grad():
|
||||
self._omega.copy_(orthogonalization(self._omega))
|
||||
|
||||
|
||||
class GLVQ1(GLVQ):
|
||||
"""Generalized Learning Vector Quantization 1."""
|
||||
def __init__(self, hparams, **kwargs):
|
||||
|
Loading…
Reference in New Issue
Block a user