chore: update all examples to pytorch 1.6

Author: Alexander Engelsberger, 2022-05-17 12:03:43 +02:00
Parent: c00513ae0d
Commit: 76fea3f881
19 changed files with 453 additions and 200 deletions
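
The same migration pattern repeats across all 19 example scripts. The condensed sketch below is assembled from the hunks that follow rather than copied from any single file; the hyperparameter values and the final trainer.fit() call are illustrative fill-ins, not part of this diff.

"""Sketch of the post-update example structure (illustrative, based on the GLVQ/Iris hunks)."""
import argparse
import warnings

import prototorch as pt
import pytorch_lightning as pl
import torch
from prototorch.models import GLVQ, VisGLVQ2D  # models are now imported directly
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader  # replaces torch.utils.data.DataLoader

# New in this commit: PyTorch Lightning warnings are silenced at module level.
warnings.filterwarnings("ignore", category=PossibleUserWarning)
warnings.filterwarnings("ignore", category=UserWarning)

if __name__ == "__main__":
    # Reproducibility: seeding moved to the top of each example
    seed_everything(seed=4)

    parser = argparse.ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    train_ds = pt.datasets.Iris(dims=[0, 2])
    train_loader = DataLoader(train_ds, batch_size=64, num_workers=4)

    # Illustrative hyperparameters; each example keeps its own values.
    hparams = dict(distribution=[2, 2, 2], lr=0.01)
    model = GLVQ(
        hparams,
        optimizer=torch.optim.Adam,
        prototypes_initializer=pt.initializers.SMCI(train_ds),
    )

    vis = VisGLVQ2D(data=train_ds)

    # weights_summary, terminate_on_nan and accelerator="ddp" are gone;
    # detect_anomaly, log_every_n_steps and an explicit max_epochs replace them.
    trainer = pl.Trainer.from_argparse_args(
        args,
        callbacks=[
            vis,
        ],
        max_epochs=100,
        log_every_n_steps=1,
        detect_anomaly=True,
    )

    # Training loop (not shown in the hunks; typical usage)
    trainer.fit(model, train_loader)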

View File

@@ -1,12 +1,22 @@
 """CBC example using the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
-import torch
+from prototorch.models import CBC, VisCBC2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -15,11 +25,8 @@ if __name__ == "__main__":
     # Dataset
     train_ds = pt.datasets.Iris(dims=[0, 2])
-    # Reproducibility
-    pl.utilities.seed.seed_everything(seed=42)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=32)
+    train_loader = DataLoader(train_ds, batch_size=32)
     # Hyperparameters
     hparams = dict(
@@ -30,23 +37,30 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.CBC(
+    model = CBC(
         hparams,
-        components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
-        reasonings_iniitializer=pt.initializers.
+        components_initializer=pt.initializers.SSCI(train_ds, noise=0.1),
+        reasonings_initializer=pt.initializers.
         PurePositiveReasoningsInitializer(),
     )
     # Callbacks
-    vis = pt.models.VisCBC2D(data=train_ds,
-                             title="CBC Iris Example",
-                             resolution=100,
-                             axis_off=True)
+    vis = VisCBC2D(
+        data=train_ds,
+        title="CBC Iris Example",
+        resolution=100,
+        axis_off=True,
+    )
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+        ],
+        detect_anomaly=True,
+        log_every_n_steps=1,
+        max_epochs=1000,
     )
     # Training loop

View File

@@ -1,12 +1,29 @@
 """Dynamically prune 'loser' prototypes in GLVQ-type models."""
 import argparse
+import logging
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import (
+    CELVQ,
+    PruneLoserPrototypes,
+    VisGLVQ2D,
+)
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -16,15 +33,17 @@ if __name__ == "__main__":
     num_classes = 4
     num_features = 2
     num_clusters = 1
-    train_ds = pt.datasets.Random(num_samples=500,
-                                  num_classes=num_classes,
-                                  num_features=num_features,
-                                  num_clusters=num_clusters,
-                                  separation=3.0,
-                                  seed=42)
+    train_ds = pt.datasets.Random(
+        num_samples=500,
+        num_classes=num_classes,
+        num_features=num_features,
+        num_clusters=num_clusters,
+        separation=3.0,
+        seed=42,
+    )
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
+    train_loader = DataLoader(train_ds, batch_size=256)
     # Hyperparameters
     prototypes_per_class = num_clusters * 5
@@ -34,7 +53,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.CELVQ(
+    model = CELVQ(
         hparams,
         prototypes_initializer=pt.initializers.FVCI(2, 3.0),
     )
@@ -43,18 +62,18 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 2)
     # Summary
-    print(model)
+    logging.info(model)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(train_ds)
-    pruning = pt.models.PruneLoserPrototypes(
+    vis = VisGLVQ2D(train_ds)
+    pruning = PruneLoserPrototypes(
         threshold=0.01,  # prune prototype if it wins less than 1%
         idle_epochs=20,  # pruning too early may cause problems
         prune_quota_per_epoch=2,  # prune at most 2 prototypes per epoch
         frequency=1,  # prune every epoch
         verbose=True,
     )
-    es = pl.callbacks.EarlyStopping(
+    es = EarlyStopping(
         monitor="train_loss",
         min_delta=0.001,
         patience=20,
@@ -71,10 +90,9 @@ if __name__ == "__main__":
             pruning,
             es,
         ],
-        progress_bar_refresh_rate=0,
-        terminate_on_nan=True,
-        weights_summary="full",
-        accelerator="ddp",
+        detect_anomaly=True,
+        log_every_n_steps=1,
+        max_epochs=1000,
     )
     # Training loop

View File

@@ -1,13 +1,24 @@
 """GLVQ example using the Iris dataset."""
 import argparse
+import logging
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import GLVQ, VisGLVQ2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=UserWarning)
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -17,7 +28,7 @@ if __name__ == "__main__":
     train_ds = pt.datasets.Iris(dims=[0, 2])
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
+    train_loader = DataLoader(train_ds, batch_size=64, num_workers=4)
     # Hyperparameters
     hparams = dict(
@@ -29,7 +40,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.GLVQ(
+    model = GLVQ(
         hparams,
         optimizer=torch.optim.Adam,
         prototypes_initializer=pt.initializers.SMCI(train_ds),
@@ -41,13 +52,17 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 2)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=train_ds)
+    vis = VisGLVQ2D(data=train_ds)
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
-        weights_summary="full",
+        callbacks=[
+            vis,
+        ],
+        max_epochs=100,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop
@@ -57,8 +72,8 @@ if __name__ == "__main__":
     trainer.save_checkpoint("./glvq_iris.ckpt")
     # Load saved model
-    new_model = pt.models.GLVQ.load_from_checkpoint(
+    new_model = GLVQ.load_from_checkpoint(
         checkpoint_path="./glvq_iris.ckpt",
         strict=False,
     )
-    print(new_model)
+    logging.info(new_model)

View File

@@ -1,13 +1,25 @@
 """GMLVQ example using the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import GMLVQ, VisGMLVQ2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -17,7 +29,7 @@ if __name__ == "__main__":
     train_ds = pt.datasets.Iris()
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
+    train_loader = DataLoader(train_ds, batch_size=64)
     # Hyperparameters
     hparams = dict(
@@ -32,7 +44,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.GMLVQ(
+    model = GMLVQ(
         hparams,
         optimizer=torch.optim.Adam,
         prototypes_initializer=pt.initializers.SMCI(train_ds),
@@ -44,14 +56,17 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 4)
     # Callbacks
-    vis = pt.models.VisGMLVQ2D(data=train_ds)
+    vis = VisGMLVQ2D(data=train_ds)
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
-        weights_summary="full",
-        accelerator="ddp",
+        callbacks=[
+            vis,
+        ],
+        max_epochs=100,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,14 +1,29 @@
 """GMLVQ example using the MNIST dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import (
+    ImageGMLVQ,
+    PruneLoserPrototypes,
+    VisImgComp,
+)
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
 from torchvision import transforms
 from torchvision.datasets import MNIST
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -33,12 +48,8 @@ if __name__ == "__main__":
     )
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds,
-                                               num_workers=0,
-                                               batch_size=256)
-    test_loader = torch.utils.data.DataLoader(test_ds,
-                                              num_workers=0,
-                                              batch_size=256)
+    train_loader = DataLoader(train_ds, num_workers=4, batch_size=256)
+    test_loader = DataLoader(test_ds, num_workers=4, batch_size=256)
     # Hyperparameters
     num_classes = 10
@@ -52,14 +63,14 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.ImageGMLVQ(
+    model = ImageGMLVQ(
         hparams,
         optimizer=torch.optim.Adam,
         prototypes_initializer=pt.initializers.SMCI(train_ds),
     )
     # Callbacks
-    vis = pt.models.VisImgComp(
+    vis = VisImgComp(
         data=train_ds,
         num_columns=10,
         show=False,
@@ -69,14 +80,14 @@ if __name__ == "__main__":
         embedding_data=200,
         flatten_data=False,
     )
-    pruning = pt.models.PruneLoserPrototypes(
+    pruning = PruneLoserPrototypes(
         threshold=0.01,
         idle_epochs=1,
         prune_quota_per_epoch=10,
         frequency=1,
         verbose=True,
     )
-    es = pl.callbacks.EarlyStopping(
+    es = EarlyStopping(
         monitor="train_loss",
         min_delta=0.001,
         patience=15,
@@ -90,11 +101,11 @@ if __name__ == "__main__":
         callbacks=[
             vis,
             pruning,
-            # es,
+            es,
         ],
-        terminate_on_nan=True,
-        weights_summary=None,
-        # accelerator="ddp",
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,12 +1,28 @@
 """GMLVQ example using the spiral dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import (
+    GMLVQ,
+    PruneLoserPrototypes,
+    VisGLVQ2D,
+)
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -16,7 +32,7 @@ if __name__ == "__main__":
     train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
+    train_loader = DataLoader(train_ds, batch_size=256)
     # Hyperparameters
     num_classes = 2
@@ -32,19 +48,19 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.GMLVQ(
+    model = GMLVQ(
         hparams,
         optimizer=torch.optim.Adam,
         prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
     )
     # Callbacks
-    vis = pt.models.VisGLVQ2D(
+    vis = VisGLVQ2D(
         train_ds,
         show_last_only=False,
         block=False,
     )
-    pruning = pt.models.PruneLoserPrototypes(
+    pruning = PruneLoserPrototypes(
         threshold=0.01,
         idle_epochs=10,
         prune_quota_per_epoch=5,
@@ -53,7 +69,7 @@ if __name__ == "__main__":
         prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
         verbose=True,
     )
-    es = pl.callbacks.EarlyStopping(
+    es = EarlyStopping(
         monitor="train_loss",
         min_delta=1.0,
         patience=5,
@@ -69,7 +85,9 @@ if __name__ == "__main__":
             es,
             pruning,
         ],
-        terminate_on_nan=True,
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,10 +1,19 @@
 """Growing Neural Gas example using the Iris dataset."""
 import argparse
+import logging
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import GrowingNeuralGas, VisNG2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
     # Command-line arguments
@@ -13,11 +22,11 @@ if __name__ == "__main__":
     args = parser.parse_args()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=42)
+    seed_everything(seed=42)
     # Prepare the data
     train_ds = pt.datasets.Iris(dims=[0, 2])
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
+    train_loader = DataLoader(train_ds, batch_size=64)
     # Hyperparameters
     hparams = dict(
@@ -27,7 +36,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.GrowingNeuralGas(
+    model = GrowingNeuralGas(
         hparams,
         prototypes_initializer=pt.initializers.ZCI(2),
     )
@@ -36,17 +45,20 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 2)
     # Model summary
-    print(model)
+    logging.info(model)
     # Callbacks
-    vis = pt.models.VisNG2D(data=train_loader)
+    vis = VisNG2D(data=train_loader)
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
+        callbacks=[
+            vis,
+        ],
         max_epochs=100,
-        callbacks=[vis],
-        weights_summary="full",
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,14 +1,30 @@
 """GTLVQ example using the MNIST dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import (
+    ImageGTLVQ,
+    PruneLoserPrototypes,
+    VisImgComp,
+)
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
 from torchvision import transforms
 from torchvision.datasets import MNIST
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -33,12 +49,8 @@ if __name__ == "__main__":
     )
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds,
-                                               num_workers=0,
-                                               batch_size=256)
-    test_loader = torch.utils.data.DataLoader(test_ds,
-                                              num_workers=0,
-                                              batch_size=256)
+    train_loader = DataLoader(train_ds, num_workers=0, batch_size=256)
+    test_loader = DataLoader(test_ds, num_workers=0, batch_size=256)
     # Hyperparameters
     num_classes = 10
@@ -52,7 +64,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.ImageGTLVQ(
+    model = ImageGTLVQ(
         hparams,
         optimizer=torch.optim.Adam,
         prototypes_initializer=pt.initializers.SMCI(train_ds),
@@ -61,7 +73,7 @@ if __name__ == "__main__":
             next(iter(train_loader))[0].reshape(256, 28 * 28)))
     # Callbacks
-    vis = pt.models.VisImgComp(
+    vis = VisImgComp(
         data=train_ds,
         num_columns=10,
         show=False,
@@ -71,14 +83,14 @@ if __name__ == "__main__":
         embedding_data=200,
         flatten_data=False,
     )
-    pruning = pt.models.PruneLoserPrototypes(
+    pruning = PruneLoserPrototypes(
         threshold=0.01,
         idle_epochs=1,
         prune_quota_per_epoch=10,
         frequency=1,
         verbose=True,
     )
-    es = pl.callbacks.EarlyStopping(
+    es = EarlyStopping(
         monitor="train_loss",
         min_delta=0.001,
         patience=15,
@@ -93,11 +105,11 @@ if __name__ == "__main__":
         callbacks=[
             vis,
             pruning,
-            # es,
+            es,
         ],
-        terminate_on_nan=True,
-        weights_summary=None,
-        accelerator="ddp",
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,10 +1,20 @@
 """Localized-GTLVQ example using the Moons dataset."""
 import argparse
+import logging
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import GTLVQ, VisGLVQ2D
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
     # Command-line arguments
@@ -13,33 +23,35 @@ if __name__ == "__main__":
     args = parser.parse_args()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=2)
+    seed_everything(seed=2)
     # Dataset
     train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds,
-                                               batch_size=256,
-                                               shuffle=True)
+    train_loader = DataLoader(
+        train_ds,
+        batch_size=256,
+        shuffle=True,
+    )
     # Hyperparameters
     # Latent_dim should be lower than input dim.
     hparams = dict(distribution=[1, 3], input_dim=2, latent_dim=1)
     # Initialize the model
-    model = pt.models.GTLVQ(
-        hparams, prototypes_initializer=pt.initializers.SMCI(train_ds))
+    model = GTLVQ(hparams,
+                  prototypes_initializer=pt.initializers.SMCI(train_ds))
     # Compute intermediate input and output sizes
     model.example_input_array = torch.zeros(4, 2)
     # Summary
-    print(model)
+    logging.info(model)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=train_ds)
-    es = pl.callbacks.EarlyStopping(
+    vis = VisGLVQ2D(data=train_ds)
+    es = EarlyStopping(
         monitor="train_acc",
         min_delta=0.001,
         patience=20,
@@ -55,8 +67,9 @@ if __name__ == "__main__":
             vis,
             es,
         ],
-        weights_summary="full",
-        accelerator="ddp",
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,12 +1,19 @@
 """k-NN example using the Iris dataset from scikit-learn."""
 import argparse
+import logging
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import KNN, VisGLVQ2D
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from sklearn.datasets import load_iris
 from sklearn.model_selection import train_test_split
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
 if __name__ == "__main__":
     # Command-line arguments
@@ -16,34 +23,36 @@ if __name__ == "__main__":
     # Dataset
     X, y = load_iris(return_X_y=True)
-    X = X[:, [0, 2]]
+    X = X[:, 0:3:2]
-    X_train, X_test, y_train, y_test = train_test_split(X,
-                                                        y,
-                                                        test_size=0.5,
-                                                        random_state=42)
+    X_train, X_test, y_train, y_test = train_test_split(
+        X,
+        y,
+        test_size=0.5,
+        random_state=42,
+    )
     train_ds = pt.datasets.NumpyDataset(X_train, y_train)
     test_ds = pt.datasets.NumpyDataset(X_test, y_test)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=16)
-    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=16)
+    train_loader = DataLoader(train_ds, batch_size=16)
+    test_loader = DataLoader(test_ds, batch_size=16)
     # Hyperparameters
     hparams = dict(k=5)
     # Initialize the model
-    model = pt.models.KNN(hparams, data=train_ds)
+    model = KNN(hparams, data=train_ds)
     # Compute intermediate input and output sizes
     model.example_input_array = torch.zeros(4, 2)
     # Summary
-    print(model)
+    logging.info(model)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(
+    vis = VisGLVQ2D(
         data=(X_train, y_train),
         resolution=200,
         block=True,
@@ -53,8 +62,11 @@ if __name__ == "__main__":
     trainer = pl.Trainer.from_argparse_args(
         args,
         max_epochs=1,
-        callbacks=[vis],
-        weights_summary="full",
+        callbacks=[
+            vis,
+        ],
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop
@@ -63,7 +75,7 @@ if __name__ == "__main__":
     # Recall
     y_pred = model.predict(torch.tensor(X_train))
-    print(y_pred)
+    logging.info(y_pred)
     # Test
     trainer.test(model, dataloaders=test_loader)

View File

@@ -1,12 +1,21 @@
 """Kohonen Self Organizing Map."""
 import argparse
+import logging
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from matplotlib import pyplot as plt
+from prototorch.models import KohonenSOM
 from prototorch.utils.colors import hex_to_rgb
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader, TensorDataset
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 class Vis2DColorSOM(pl.Callback):
@@ -18,7 +27,7 @@ class Vis2DColorSOM(pl.Callback):
         self.data = data
         self.pause_time = pause_time
-    def on_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_end(self, trainer, pl_module: KohonenSOM):
         ax = self.fig.gca()
         ax.cla()
         ax.set_title(self.title)
@@ -31,12 +40,14 @@ class Vis2DColorSOM(pl.Callback):
         d = pl_module.compute_distances(self.data)
         wp = pl_module.predict_from_distances(d)
         for i, iloc in enumerate(wp):
-            plt.text(iloc[1],
-                     iloc[0],
-                     cnames[i],
-                     ha="center",
-                     va="center",
-                     bbox=dict(facecolor="white", alpha=0.5, lw=0))
+            plt.text(
+                iloc[1],
+                iloc[0],
+                color_names[i],
+                ha="center",
+                va="center",
+                bbox=dict(facecolor="white", alpha=0.5, lw=0),
+            )
         if trainer.current_epoch != trainer.max_epochs - 1:
             plt.pause(self.pause_time)
@@ -51,7 +62,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=42)
+    seed_everything(seed=42)
     # Prepare the data
     hex_colors = [
@@ -59,15 +70,15 @@ if __name__ == "__main__":
         "#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
         "#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
     ]
-    cnames = [
+    color_names = [
         "black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
         "red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
         "lightgrey", "olive", "purple", "orange"
     ]
     colors = list(hex_to_rgb(hex_colors))
     data = torch.Tensor(colors) / 255.0
-    train_ds = torch.utils.data.TensorDataset(data)
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
+    train_ds = TensorDataset(data)
+    train_loader = DataLoader(train_ds, batch_size=8)
     # Hyperparameters
     hparams = dict(
@@ -78,7 +89,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.KohonenSOM(
+    model = KohonenSOM(
         hparams,
         prototypes_initializer=pt.initializers.RNCI(3),
     )
@@ -87,7 +98,7 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 3)
     # Model summary
-    print(model)
+    logging.info(model)
     # Callbacks
     vis = Vis2DColorSOM(data=data)
@@ -96,8 +107,11 @@ if __name__ == "__main__":
     trainer = pl.Trainer.from_argparse_args(
         args,
         max_epochs=500,
-        callbacks=[vis],
-        weights_summary="full",
+        callbacks=[
+            vis,
+        ],
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,10 +1,20 @@
 """Localized-GMLVQ example using the Moons dataset."""
 import argparse
+import logging
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import LGMLVQ, VisGLVQ2D
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
     # Command-line arguments
@@ -13,15 +23,13 @@ if __name__ == "__main__":
     args = parser.parse_args()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=2)
+    seed_everything(seed=2)
     # Dataset
     train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds,
-                                               batch_size=256,
-                                               shuffle=True)
+    train_loader = DataLoader(train_ds, batch_size=256, shuffle=True)
     # Hyperparameters
     hparams = dict(
@@ -31,7 +39,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.LGMLVQ(
+    model = LGMLVQ(
         hparams,
         prototypes_initializer=pt.initializers.SMCI(train_ds),
     )
@@ -40,11 +48,11 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 2)
     # Summary
-    print(model)
+    logging.info(model)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=train_ds)
-    es = pl.callbacks.EarlyStopping(
+    vis = VisGLVQ2D(data=train_ds)
+    es = EarlyStopping(
         monitor="train_acc",
         min_delta=0.001,
         patience=20,
@@ -60,8 +68,9 @@ if __name__ == "__main__":
             vis,
             es,
         ],
-        weights_summary="full",
-        accelerator="ddp",
+        log_every_n_steps=1,
+        max_epochs=1000,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,10 +1,22 @@
 """LVQMLN example using all four dimensions of the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import (
+    LVQMLN,
+    PruneLoserPrototypes,
+    VisSiameseGLVQ2D,
+)
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 class Backbone(torch.nn.Module):
@@ -34,10 +46,10 @@ if __name__ == "__main__":
     train_ds = pt.datasets.Iris()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=42)
+    seed_everything(seed=42)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
+    train_loader = DataLoader(train_ds, batch_size=150)
     # Hyperparameters
     hparams = dict(
@@ -50,7 +62,7 @@ if __name__ == "__main__":
     backbone = Backbone()
     # Initialize the model
-    model = pt.models.LVQMLN(
+    model = LVQMLN(
         hparams,
         prototypes_initializer=pt.initializers.SSCI(
             train_ds,
@@ -59,18 +71,15 @@ if __name__ == "__main__":
         backbone=backbone,
     )
-    # Model summary
-    print(model)
     # Callbacks
-    vis = pt.models.VisSiameseGLVQ2D(
+    vis = VisSiameseGLVQ2D(
         data=train_ds,
         map_protos=False,
         border=0.1,
         resolution=500,
         axis_off=True,
     )
-    pruning = pt.models.PruneLoserPrototypes(
+    pruning = PruneLoserPrototypes(
         threshold=0.01,
         idle_epochs=20,
         prune_quota_per_epoch=2,
@@ -85,6 +94,9 @@ if __name__ == "__main__":
             vis,
             pruning,
         ],
+        log_every_n_steps=1,
+        max_epochs=1000,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,12 +1,23 @@
 """Median-LVQ example using the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import MedianLVQ, VisGLVQ2D
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -16,13 +27,13 @@ if __name__ == "__main__":
     train_ds = pt.datasets.Iris(dims=[0, 2])
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(
+    train_loader = DataLoader(
         train_ds,
         batch_size=len(train_ds),  # MedianLVQ cannot handle mini-batches
     )
     # Initialize the model
-    model = pt.models.MedianLVQ(
+    model = MedianLVQ(
         hparams=dict(distribution=(3, 2), lr=0.01),
         prototypes_initializer=pt.initializers.SSCI(train_ds),
     )
@@ -31,8 +42,8 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 2)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=train_ds)
-    es = pl.callbacks.EarlyStopping(
+    vis = VisGLVQ2D(data=train_ds)
+    es = EarlyStopping(
         monitor="train_acc",
         min_delta=0.01,
         patience=5,
@@ -44,8 +55,13 @@ if __name__ == "__main__":
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis, es],
-        weights_summary="full",
+        callbacks=[
+            vis,
+            es,
+        ],
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,15 +1,26 @@
 """Neural Gas example using the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import NeuralGas, VisNG2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
 from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -17,7 +28,7 @@ if __name__ == "__main__":
     # Prepare and pre-process the dataset
     x_train, y_train = load_iris(return_X_y=True)
-    x_train = x_train[:, [0, 2]]
+    x_train = x_train[:, 0:3:2]
     scaler = StandardScaler()
     scaler.fit(x_train)
     x_train = scaler.transform(x_train)
@@ -25,7 +36,7 @@ if __name__ == "__main__":
     train_ds = pt.datasets.NumpyDataset(x_train, y_train)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
+    train_loader = DataLoader(train_ds, batch_size=150)
     # Hyperparameters
     hparams = dict(
@@ -35,7 +46,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.NeuralGas(
+    model = NeuralGas(
         hparams,
         prototypes_initializer=pt.core.ZCI(2),
         lr_scheduler=ExponentialLR,
@@ -45,17 +56,18 @@ if __name__ == "__main__":
     # Compute intermediate input and output sizes
     model.example_input_array = torch.zeros(4, 2)
-    # Model summary
-    print(model)
     # Callbacks
-    vis = pt.models.VisNG2D(data=train_ds)
+    vis = VisNG2D(data=train_ds)
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
-        weights_summary="full",
+        callbacks=[
+            vis,
+        ],
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,10 +1,18 @@
 """RSLVQ example using the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import RSLVQ, VisGLVQ2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 if __name__ == "__main__":
     # Command-line arguments
@@ -13,13 +21,13 @@ if __name__ == "__main__":
     args = parser.parse_args()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=42)
+    seed_everything(seed=42)
     # Dataset
     train_ds = pt.datasets.Iris(dims=[0, 2])
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
+    train_loader = DataLoader(train_ds, batch_size=64)
     # Hyperparameters
     hparams = dict(
@@ -33,7 +41,7 @@ if __name__ == "__main__":
     )
     # Initialize the model
-    model = pt.models.RSLVQ(
+    model = RSLVQ(
         hparams,
         optimizer=torch.optim.Adam,
         prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
@@ -42,19 +50,18 @@ if __name__ == "__main__":
     # Compute intermediate input and output sizes
     model.example_input_array = torch.zeros(4, 2)
-    # Summary
-    print(model)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=train_ds)
+    vis = VisGLVQ2D(data=train_ds)
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
-        terminate_on_nan=True,
-        weights_summary="full",
-        accelerator="ddp",
+        callbacks=[
+            vis,
+        ],
+        detect_anomaly=True,
+        max_epochs=100,
+        log_every_n_steps=1,
    )
     # Training loop

View File

@@ -1,10 +1,18 @@
 """Siamese GLVQ example using all four dimensions of the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import SiameseGLVQ, VisSiameseGLVQ2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 class Backbone(torch.nn.Module):
@@ -34,10 +42,10 @@ if __name__ == "__main__":
     train_ds = pt.datasets.Iris()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=2)
+    seed_everything(seed=2)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
+    train_loader = DataLoader(train_ds, batch_size=150)
     # Hyperparameters
     hparams = dict(
@@ -50,23 +58,25 @@ if __name__ == "__main__":
     backbone = Backbone()
     # Initialize the model
-    model = pt.models.SiameseGLVQ(
+    model = SiameseGLVQ(
         hparams,
         prototypes_initializer=pt.initializers.SMCI(train_ds),
         backbone=backbone,
         both_path_gradients=False,
     )
-    # Model summary
-    print(model)
     # Callbacks
-    vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
+    vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+        ],
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,10 +1,18 @@
 """Siamese GTLVQ example using all four dimensions of the Iris dataset."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import SiameseGTLVQ, VisSiameseGLVQ2D
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
+warnings.filterwarnings("ignore", category=UserWarning)
 class Backbone(torch.nn.Module):
@@ -34,39 +42,43 @@ if __name__ == "__main__":
     train_ds = pt.datasets.Iris()
     # Reproducibility
-    pl.utilities.seed.seed_everything(seed=2)
+    seed_everything(seed=2)
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
+    train_loader = DataLoader(train_ds, batch_size=150)
     # Hyperparameters
-    hparams = dict(distribution=[1, 2, 3],
-                   proto_lr=0.01,
-                   bb_lr=0.01,
-                   input_dim=2,
-                   latent_dim=1)
+    hparams = dict(
+        distribution=[1, 2, 3],
+        proto_lr=0.01,
+        bb_lr=0.01,
+        input_dim=2,
+        latent_dim=1,
+    )
     # Initialize the backbone
     backbone = Backbone(latent_size=hparams["input_dim"])
     # Initialize the model
-    model = pt.models.SiameseGTLVQ(
+    model = SiameseGTLVQ(
         hparams,
         prototypes_initializer=pt.initializers.SMCI(train_ds),
         backbone=backbone,
         both_path_gradients=False,
     )
-    # Model summary
-    print(model)
     # Callbacks
-    vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
+    vis = VisSiameseGLVQ2D(data=train_ds, border=0.1)
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+        ],
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop

View File

@@ -1,13 +1,30 @@
 """Warm-starting GLVQ with prototypes from Growing Neural Gas."""
 import argparse
+import warnings
 import prototorch as pt
 import pytorch_lightning as pl
 import torch
+from prototorch.models import (
+    GLVQ,
+    KNN,
+    GrowingNeuralGas,
+    PruneLoserPrototypes,
+    VisGLVQ2D,
+)
+from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning.utilities.warnings import PossibleUserWarning
 from torch.optim.lr_scheduler import ExponentialLR
+from torch.utils.data import DataLoader
+warnings.filterwarnings("ignore", category=PossibleUserWarning)
 if __name__ == "__main__":
+    # Reproducibility
+    seed_everything(seed=4)
     # Command-line arguments
     parser = argparse.ArgumentParser()
     parser = pl.Trainer.add_argparse_args(parser)
@@ -15,10 +32,10 @@ if __name__ == "__main__":
     # Prepare the data
     train_ds = pt.datasets.Iris(dims=[0, 2])
-    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
+    train_loader = DataLoader(train_ds, batch_size=64, num_workers=0)
     # Initialize the gng
-    gng = pt.models.GrowingNeuralGas(
+    gng = GrowingNeuralGas(
         hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
         prototypes_initializer=pt.initializers.ZCI(2),
         lr_scheduler=ExponentialLR,
@@ -26,7 +43,7 @@ if __name__ == "__main__":
     )
     # Callbacks
-    es = pl.callbacks.EarlyStopping(
+    es = EarlyStopping(
         monitor="loss",
         min_delta=0.001,
         patience=20,
@@ -37,9 +54,12 @@ if __name__ == "__main__":
     # Setup trainer for GNG
     trainer = pl.Trainer(
-        max_epochs=100,
-        callbacks=[es],
-        weights_summary=None,
+        max_epochs=1000,
+        callbacks=[
+            es,
+        ],
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop
@@ -52,12 +72,12 @@ if __name__ == "__main__":
     )
     # Warm-start prototypes
-    knn = pt.models.KNN(dict(k=1), data=train_ds)
+    knn = KNN(dict(k=1), data=train_ds)
     prototypes = gng.prototypes
     plabels = knn.predict(prototypes)
     # Initialize the model
-    model = pt.models.GLVQ(
+    model = GLVQ(
         hparams,
         optimizer=torch.optim.Adam,
         prototypes_initializer=pt.initializers.LCI(prototypes),
@@ -70,15 +90,15 @@ if __name__ == "__main__":
     model.example_input_array = torch.zeros(4, 2)
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=train_ds)
-    pruning = pt.models.PruneLoserPrototypes(
+    vis = VisGLVQ2D(data=train_ds)
+    pruning = PruneLoserPrototypes(
         threshold=0.02,
         idle_epochs=2,
         prune_quota_per_epoch=5,
         frequency=1,
         verbose=True,
     )
-    es = pl.callbacks.EarlyStopping(
+    es = EarlyStopping(
         monitor="train_loss",
         min_delta=0.001,
         patience=10,
@@ -95,8 +115,9 @@ if __name__ == "__main__":
             pruning,
             es,
         ],
-        weights_summary="full",
-        accelerator="ddp",
+        max_epochs=1000,
+        log_every_n_steps=1,
+        detect_anomaly=True,
     )
     # Training loop