feat: ImageGTLVQ and SiameseGTLVQ with examples
This commit is contained in:
parent
d3bb430104
commit
a9edf06507
104
examples/gtlvq_mnist.py
Normal file
104
examples/gtlvq_mnist.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
"""GMLVQ example using the MNIST dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from torchvision.datasets import MNIST
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = MNIST(
|
||||||
|
"~/datasets",
|
||||||
|
train=True,
|
||||||
|
download=True,
|
||||||
|
transform=transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
test_ds = MNIST(
|
||||||
|
"~/datasets",
|
||||||
|
train=False,
|
||||||
|
download=True,
|
||||||
|
transform=transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=256)
|
||||||
|
test_loader = torch.utils.data.DataLoader(test_ds,
|
||||||
|
num_workers=0,
|
||||||
|
batch_size=256)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
num_classes = 10
|
||||||
|
prototypes_per_class = 1
|
||||||
|
hparams = dict(
|
||||||
|
input_dim=28 * 28,
|
||||||
|
latent_dim=28,
|
||||||
|
distribution=(num_classes, prototypes_per_class),
|
||||||
|
proto_lr=0.01,
|
||||||
|
bb_lr=0.01,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.ImageGTLVQ(
|
||||||
|
hparams,
|
||||||
|
optimizer=torch.optim.Adam,
|
||||||
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
|
# Use one batch of data for subspace initiator.
|
||||||
|
# omega_initializer=pt.initializers.PCALinearTransformInitializer(next(iter(train_loader))[0].reshape(256,28*28))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisImgComp(
|
||||||
|
data=train_ds,
|
||||||
|
num_columns=10,
|
||||||
|
show=False,
|
||||||
|
tensorboard=True,
|
||||||
|
random_data=100,
|
||||||
|
add_embedding=True,
|
||||||
|
embedding_data=200,
|
||||||
|
flatten_data=False,
|
||||||
|
)
|
||||||
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
|
threshold=0.01,
|
||||||
|
idle_epochs=1,
|
||||||
|
prune_quota_per_epoch=10,
|
||||||
|
frequency=1,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_loss",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=15,
|
||||||
|
mode="min",
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
# using GPUs here is strongly recommended!
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
pruning,
|
||||||
|
# es,
|
||||||
|
],
|
||||||
|
terminate_on_nan=True,
|
||||||
|
weights_summary=None,
|
||||||
|
accelerator="ddp",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
@ -24,79 +24,12 @@ if __name__ == "__main__":
|
|||||||
shuffle=True)
|
shuffle=True)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
|
# Latent_dim should be lower than input dim.
|
||||||
|
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GTLVQ(
|
model = pt.models.GTLVQ(
|
||||||
hparams,
|
hparams, prototypes_initializer=pt.initializers.SMCI(train_ds))
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
|
||||||
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
|
|
||||||
train_ds))
|
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
# Callbacks
|
|
||||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
|
||||||
es = pl.callbacks.EarlyStopping(
|
|
||||||
monitor="train_acc",
|
|
||||||
min_delta=0.001,
|
|
||||||
patience=20,
|
|
||||||
mode="max",
|
|
||||||
verbose=False,
|
|
||||||
check_on_train_epoch_end=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setup trainer
|
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
|
||||||
args,
|
|
||||||
callbacks=[
|
|
||||||
vis,
|
|
||||||
es,
|
|
||||||
],
|
|
||||||
weights_summary="full",
|
|
||||||
accelerator="ddp",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
trainer.fit(model, train_loader)
|
|
||||||
"""Localized-GMLVQ example using the Moons dataset."""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import pytorch_lightning as pl
|
|
||||||
import torch
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Command-line arguments
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser = pl.Trainer.add_argparse_args(parser)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Reproducibility
|
|
||||||
pl.utilities.seed.seed_everything(seed=2)
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
|
||||||
|
|
||||||
# Dataloaders
|
|
||||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
|
||||||
batch_size=256,
|
|
||||||
shuffle=True)
|
|
||||||
|
|
||||||
# Hyperparameters
|
|
||||||
hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=2)
|
|
||||||
|
|
||||||
# Initialize the model
|
|
||||||
model = pt.models.GTLVQ(
|
|
||||||
hparams,
|
|
||||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
|
||||||
omega_initializer=-pt.initializers.PCALinearTransformInitializer(
|
|
||||||
train_ds))
|
|
||||||
|
|
||||||
# Compute intermediate input and output sizes
|
# Compute intermediate input and output sizes
|
||||||
model.example_input_array = torch.zeros(4, 2)
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
72
examples/siamese_gtlvq_iris.py
Normal file
72
examples/siamese_gtlvq_iris.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""Siamese GLVQ example using all four dimensions of the Iris dataset."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Backbone(torch.nn.Module):
|
||||||
|
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.latent_size = latent_size
|
||||||
|
self.dense1 = torch.nn.Linear(self.input_size, self.hidden_size)
|
||||||
|
self.dense2 = torch.nn.Linear(self.hidden_size, self.latent_size)
|
||||||
|
self.activation = torch.nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.activation(self.dense1(x))
|
||||||
|
out = self.activation(self.dense2(x))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser = pl.Trainer.add_argparse_args(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Dataset
|
||||||
|
train_ds = pt.datasets.Iris()
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
pl.utilities.seed.seed_everything(seed=2)
|
||||||
|
|
||||||
|
# Dataloaders
|
||||||
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||||
|
|
||||||
|
# Hyperparameters
|
||||||
|
hparams = dict(distribution=[1, 2, 3],
|
||||||
|
proto_lr=0.01,
|
||||||
|
bb_lr=0.01,
|
||||||
|
input_dim=2,
|
||||||
|
latent_dim=1)
|
||||||
|
|
||||||
|
# Initialize the backbone
|
||||||
|
backbone = Backbone(latent_size=hparams["input_dim"])
|
||||||
|
|
||||||
|
# Initialize the model
|
||||||
|
model = pt.models.SiameseGTLVQ(
|
||||||
|
hparams,
|
||||||
|
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||||
|
backbone=backbone,
|
||||||
|
both_path_gradients=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
# Callbacks
|
||||||
|
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
|
# Setup trainer
|
||||||
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
args,
|
||||||
|
callbacks=[vis],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training loop
|
||||||
|
trainer.fit(model, train_loader)
|
@ -13,8 +13,10 @@ from .glvq import (
|
|||||||
LVQMLN,
|
LVQMLN,
|
||||||
ImageGLVQ,
|
ImageGLVQ,
|
||||||
ImageGMLVQ,
|
ImageGMLVQ,
|
||||||
|
ImageGTLVQ,
|
||||||
SiameseGLVQ,
|
SiameseGLVQ,
|
||||||
SiameseGMLVQ,
|
SiameseGMLVQ,
|
||||||
|
SiameseGTLVQ,
|
||||||
)
|
)
|
||||||
from .knn import KNN
|
from .knn import KNN
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
|
@ -284,22 +284,28 @@ class LGMLVQ(GMLVQ):
|
|||||||
|
|
||||||
|
|
||||||
class GTLVQ(LGMLVQ):
|
class GTLVQ(LGMLVQ):
|
||||||
"""Localized and Generalized Matrix Learning Vector Quantization."""
|
"""Localized and Generalized Tangent Learning Vector Quantization."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
|
distance_fn = kwargs.pop("distance_fn", ltangent_distance)
|
||||||
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
|
||||||
|
|
||||||
omega_initializer = kwargs.get("omega_initializer")
|
omega_initializer = kwargs.get("omega_initializer")
|
||||||
omega = omega_initializer.generate(self.hparams.input_dim,
|
|
||||||
self.hparams.latent_dim)
|
if omega_initializer is not None:
|
||||||
|
subspace = omega_initializer.generate(self.hparams.input_dim,
|
||||||
|
self.hparams.latent_dim)
|
||||||
|
omega = torch.repeat_interleave(subspace.unsqueeze(0),
|
||||||
|
self.num_prototypes,
|
||||||
|
dim=0)
|
||||||
|
else:
|
||||||
|
omega = torch.rand(
|
||||||
|
self.num_prototypes,
|
||||||
|
self.hparams.input_dim,
|
||||||
|
self.hparams.latent_dim,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
# Re-register `_omega` to override the one from the super class.
|
# Re-register `_omega` to override the one from the super class.
|
||||||
omega = torch.rand(
|
|
||||||
self.num_prototypes,
|
|
||||||
self.hparams.input_dim,
|
|
||||||
self.hparams.latent_dim,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
self.register_parameter("_omega", Parameter(omega))
|
self.register_parameter("_omega", Parameter(omega))
|
||||||
|
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
@ -307,6 +313,14 @@ class GTLVQ(LGMLVQ):
|
|||||||
self._omega.copy_(orthogonalization(self._omega))
|
self._omega.copy_(orthogonalization(self._omega))
|
||||||
|
|
||||||
|
|
||||||
|
class SiameseGTLVQ(SiameseGLVQ, GTLVQ):
|
||||||
|
"""Generalized Tangent Learning Vector Quantization.
|
||||||
|
|
||||||
|
Implemented as a Siamese network with a linear transformation backbone.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class GLVQ1(GLVQ):
|
class GLVQ1(GLVQ):
|
||||||
"""Generalized Learning Vector Quantization 1."""
|
"""Generalized Learning Vector Quantization 1."""
|
||||||
def __init__(self, hparams, **kwargs):
|
def __init__(self, hparams, **kwargs):
|
||||||
@ -339,3 +353,17 @@ class ImageGMLVQ(ImagePrototypesMixin, GMLVQ):
|
|||||||
after updates.
|
after updates.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ImageGTLVQ(ImagePrototypesMixin, GTLVQ):
|
||||||
|
"""GTLVQ for training on image data.
|
||||||
|
|
||||||
|
GTLVQ model that constrains the prototypes to the range [0, 1] by clamping
|
||||||
|
after updates.
|
||||||
|
|
||||||
|
"""
|
||||||
|
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||||
|
"""Constrain the components to the range [0, 1] by clamping after updates."""
|
||||||
|
self.proto_layer.components.data.clamp_(0.0, 1.0)
|
||||||
|
with torch.no_grad():
|
||||||
|
self._omega.copy_(orthogonalization(self._omega))
|
||||||
|
Loading…
Reference in New Issue
Block a user