Update example scripts

Jensun Ravichandran 2021-06-04 22:21:28 +02:00
parent 016fcb4060
commit 5d2a8226ce
7 changed files with 123 additions and 42 deletions

View File

@@ -30,7 +30,7 @@ if __name__ == "__main__":
     prototypes_per_class = num_clusters * 5
     hparams = dict(
         distribution=(num_classes, prototypes_per_class),
-        lr=0.1,
+        lr=0.2,
     )
 
     # Initialize the model
@@ -39,6 +39,12 @@ if __name__ == "__main__":
         prototype_initializer=pt.components.Ones(2, scale=3),
     )
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Summary
+    print(model)
+
     # Callbacks
     vis = pt.models.VisGLVQ2D(train_ds)
     pruning = pt.models.PruneLoserPrototypes(
@@ -67,7 +73,7 @@
         ],
         progress_bar_refresh_rate=0,
         terminate_on_nan=True,
-        weights_summary=None,
+        weights_summary="full",
         accelerator="ddp",
     )
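
Note: the `model.example_input_array = torch.zeros(4, 2)` lines added throughout this commit hand PyTorch Lightning a dummy batch so the weights summary can report per-layer input and output shapes, which is also why `weights_summary` switches from `None` to `"full"`. A minimal sketch of the same mechanism with a plain LightningModule (`TinyNet` is made up for illustration and is not part of ProtoTorch; `weights_summary=` is the Lightning <=1.4 Trainer API):

import torch
import pytorch_lightning as pl


class TinyNet(pl.LightningModule):
    """Made-up module, only here to illustrate the shape summary."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 3)
        # With example_input_array set, Lightning traces one forward pass
        # and adds "In sizes"/"Out sizes" columns to the summary table.
        self.example_input_array = torch.zeros(4, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


ds = torch.utils.data.TensorDataset(torch.randn(16, 2), torch.randn(16, 3))
loader = torch.utils.data.DataLoader(ds, batch_size=4)

# weights_summary="full" prints every layer at the start of fit();
# newer Lightning versions use the ModelSummary callback instead.
trainer = pl.Trainer(max_epochs=1, weights_summary="full")
trainer.fit(TinyNet(), loader)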

View File

@@ -2,11 +2,10 @@
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -14,35 +13,63 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Dataset
-    train_ds = pt.datasets.Spiral(num_samples=600, noise=0.6)
+    train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
 
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds,
-                                               num_workers=0,
-                                               batch_size=256)
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
 
     # Hyperparameters
     num_classes = 2
-    prototypes_per_class = 20
+    prototypes_per_class = 10
     hparams = dict(
         distribution=(num_classes, prototypes_per_class),
-        transfer_function="sigmoid_beta",
+        transfer_function="swish_beta",
         transfer_beta=10.0,
-        lr=0.01,
+        # lr=0.1,
+        proto_lr=0.1,
+        bb_lr=0.1,
+        input_dim=2,
+        latent_dim=2,
     )
 
     # Initialize the model
-    model = pt.models.GLVQ(hparams,
-                           prototype_initializer=pt.components.SSI(train_ds,
-                                                                   noise=1e-1))
+    model = pt.models.GMLVQ(
+        hparams,
+        optimizer=torch.optim.Adam,
+        prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
+    )
 
     # Callbacks
-    vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
+    vis = pt.models.VisGLVQ2D(
+        train_ds,
+        show_last_only=False,
+        block=False,
+    )
+    pruning = pt.models.PruneLoserPrototypes(
+        threshold=0.02,
+        idle_epochs=10,
+        prune_quota_per_epoch=5,
+        frequency=2,
+        replace=True,
+        initializer=pt.components.SSI(train_ds, noise=1e-2),
+        verbose=True,
+    )
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=1.0,
+        patience=5,
+        mode="min",
+        check_on_train_epoch_end=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+            # es,
+            pruning,
+        ],
         terminate_on_nan=True,
     )
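
Note: the new `es` callback can only act on a quantity that the model actually logs; presumably the ProtoTorch GLVQ-family models log a `train_loss` metric, which is what `monitor="train_loss"` hooks into. As a hedged sketch of the Lightning side of that contract, a module compatible with this callback would log its loss roughly like this (`LossLoggingExample` is a placeholder, not a ProtoTorch class):

import pytorch_lightning as pl
import torch


class LossLoggingExample(pl.LightningModule):
    """Placeholder module showing only the logging side of early stopping."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        # The key passed to `monitor=` must match this logged name.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


es = pl.callbacks.EarlyStopping(
    monitor="train_loss",           # logged key to watch
    min_delta=1.0,                  # ignore improvements smaller than this
    patience=5,                     # epochs without improvement before stopping
    mode="min",                     # lower loss is better
    check_on_train_epoch_end=True,  # evaluate after each training epoch
)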

View File

@@ -43,7 +43,7 @@ if __name__ == "__main__":
     # Hyperparameters
     num_classes = 10
-    prototypes_per_class = 2
+    prototypes_per_class = 10
     hparams = dict(
         input_dim=28 * 28,
         latent_dim=28 * 28,
@@ -62,19 +62,40 @@ if __name__ == "__main__":
     # Callbacks
     vis = pt.models.VisImgComp(
         data=train_ds,
-        num_columns=5,
+        num_columns=10,
         show=False,
         tensorboard=True,
-        random_data=20,
+        random_data=100,
         add_embedding=True,
-        embedding_data=100,
+        embedding_data=200,
         flatten_data=False,
     )
+    pruning = pt.models.PruneLoserPrototypes(
+        threshold=0.01,
+        idle_epochs=1,
+        prune_quota_per_epoch=10,
+        frequency=1,
+        verbose=True,
+    )
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=0.001,
+        patience=15,
+        mode="min",
+        check_on_train_epoch_end=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+            pruning,
+            # es,
+        ],
+        terminate_on_nan=True,
+        weights_summary=None,
+        accelerator="ddp",
     )
 
     # Training loop
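
Note: `PruneLoserPrototypes` and `EarlyStopping` plug into the Trainer through the same callbacks list. This diff does not show how the ProtoTorch pruning callback works internally; as a rough, hypothetical illustration of the hook pattern it relies on, a custom Lightning 1.x callback that inspects the model once per epoch looks like this (`PruneIdleUnits` and its logic are invented for the sketch):

import pytorch_lightning as pl


class PruneIdleUnits(pl.callbacks.Callback):
    """Invented example callback; not ProtoTorch's PruneLoserPrototypes."""

    def __init__(self, idle_epochs=1, frequency=1, verbose=False):
        self.idle_epochs = idle_epochs
        self.frequency = frequency
        self.verbose = verbose

    def on_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        if epoch < self.idle_epochs or epoch % self.frequency:
            return  # wait out the idle period, then act every `frequency` epochs
        if self.verbose:
            print(f"Epoch {epoch}: inspecting {pl_module.__class__.__name__}")
        # A real implementation would identify and remove "losing" units here.


# The callback is passed to the Trainer exactly like vis/pruning/es above:
trainer = pl.Trainer(max_epochs=5, callbacks=[PruneIdleUnits(verbose=True)])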

View File

@@ -4,10 +4,7 @@ import argparse
 
 import prototorch as pt
 import pytorch_lightning as pl
-from prototorch.components.initializers import Zeros
-from prototorch.datasets import Iris
-from prototorch.models.unsupervised import GrowingNeuralGas
-from torch.utils.data import DataLoader
+import torch
 
 if __name__ == "__main__":
     # Command-line arguments
@@ -19,8 +16,8 @@ if __name__ == "__main__":
     pl.utilities.seed.seed_everything(seed=42)
 
     # Prepare the data
-    train_ds = Iris(dims=[0, 2])
-    train_loader = DataLoader(train_ds, batch_size=8)
+    train_ds = pt.datasets.Iris(dims=[0, 2])
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
 
     # Hyperparameters
     hparams = dict(
@@ -29,11 +26,14 @@ if __name__ == "__main__":
     )
 
     # Initialize the model
-    model = GrowingNeuralGas(
+    model = pt.models.GrowingNeuralGas(
         hparams,
-        prototype_initializer=Zeros(2),
+        prototype_initializer=pt.components.Zeros(2),
     )
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
     # Model summary
     print(model)
@@ -45,6 +45,7 @@ if __name__ == "__main__":
         args,
         max_epochs=100,
         callbacks=[vis],
+        weights_summary="full",
     )
 
     # Training loop
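
Note: the deep imports are replaced by the `pt.datasets`/`pt.models`/`pt.components` namespaces. For readers without ProtoTorch at hand, a roughly equivalent two-feature Iris dataset can be built from scikit-learn directly; this sketch assumes `pt.datasets.Iris(dims=[0, 2])` selects columns 0 and 2 (sepal length and petal length):

import torch
from sklearn.datasets import load_iris
from torch.utils.data import DataLoader, TensorDataset

# Assumption: dims=[0, 2] corresponds to sepal length and petal length.
x, y = load_iris(return_X_y=True)
x = x[:, [0, 2]]

train_ds = TensorDataset(torch.Tensor(x), torch.LongTensor(y))
train_loader = DataLoader(train_ds, batch_size=8)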

View File

@@ -1,13 +1,12 @@
-"""k-NN example using the Iris dataset."""
+"""k-NN example using the Iris dataset from scikit-learn."""
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -23,18 +22,30 @@ if __name__ == "__main__":
     train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
 
     # Hyperparameters
-    hparams = dict(k=20)
+    hparams = dict(k=5)
 
     # Initialize the model
     model = pt.models.KNN(hparams, data=train_ds)
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Summary
+    print(model)
+
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=(x_train, y_train), resolution=200)
+    vis = pt.models.VisGLVQ2D(
+        data=(x_train, y_train),
+        resolution=200,
+        block=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
+        max_epochs=1,
         callbacks=[vis],
+        weights_summary="full",
     )
 
     # Training loop
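
Note: with `hparams = dict(k=5)` and `max_epochs=1`, the k-NN model simply memorizes the training data and classifies by a nearest-neighbour vote, so a single pass is all the "training" there is. A back-of-the-envelope version of that decision rule in plain PyTorch (purely illustrative, not `pt.models.KNN`'s actual code):

import torch


def knn_predict(x, x_train, y_train, k=5):
    """Naive k-nearest-neighbour majority vote (illustration only)."""
    dists = torch.cdist(x, x_train)             # pairwise Euclidean distances
    idx = dists.topk(k, largest=False).indices  # k nearest training samples
    votes = y_train[idx]                        # their labels, shape (N, k)
    return votes.mode(dim=1).values             # majority label per query


# Tiny usage example with random data:
x_train = torch.randn(150, 2)
y_train = torch.randint(0, 3, (150,))
print(knn_predict(torch.randn(4, 2), x_train, y_train, k=5))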

View File

@@ -7,6 +7,7 @@ import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
+from torch.optim.lr_scheduler import ExponentialLR
 
 if __name__ == "__main__":
     # Command-line arguments
@@ -30,8 +31,15 @@ if __name__ == "__main__":
     hparams = dict(num_prototypes=30, lr=0.03)
 
     # Initialize the model
-    model = pt.models.NeuralGas(hparams,
-                                prototype_initializer=pt.components.Zeros(2))
+    model = pt.models.NeuralGas(
+        hparams,
+        prototype_initializer=pt.components.Zeros(2),
+        lr_scheduler=ExponentialLR,
+        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
+    )
+
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
 
     # Model summary
     print(model)
@@ -43,6 +51,7 @@ if __name__ == "__main__":
     trainer = pl.Trainer.from_argparse_args(
         args,
         callbacks=[vis],
+        weights_summary="full",
     )
 
     # Training loop
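
Note: the NeuralGas example now hands `lr_scheduler=ExponentialLR` and its kwargs to the model, presumably so the model wires them into `configure_optimizers`. The scheduler itself just multiplies the learning rate by `gamma` each time it steps, i.e. roughly lr_t = 0.03 * 0.99**t here. A standalone sketch of that behaviour with a throwaway parameter:

import torch
from torch.optim.lr_scheduler import ExponentialLR

# Throwaway parameter/optimizer, only to watch the learning rate decay.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.03)
scheduler = ExponentialLR(optimizer, gamma=0.99)

for epoch in range(3):
    optimizer.step()
    scheduler.step()
    print(epoch, scheduler.get_last_lr())  # 0.03 * 0.99 ** (epoch + 1)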

View File

@@ -1,12 +1,11 @@
-"""Probabilistic-LVQ example using the Iris dataset."""
+"""RSLVQ example using the Iris dataset."""
 
 import argparse
 
-import prototorch as pt
 import pytorch_lightning as pl
 import torch
 
+import prototorch as pt
+
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -26,16 +25,23 @@ if __name__ == "__main__":
     hparams = dict(
         distribution=[2, 2, 3],
         lr=0.05,
-        variance=0.3,
+        variance=0.1,
     )
 
     # Initialize the model
     model = pt.models.probabilistic.RSLVQ(
         hparams,
         optimizer=torch.optim.Adam,
+        # prototype_initializer=pt.components.SMI(train_ds),
         prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
+        # prototype_initializer=pt.components.Zeros(2),
+        # prototype_initializer=pt.components.Ones(2, scale=2.0),
     )
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Summary
     print(model)
 
     # Callbacks
@@ -46,8 +52,8 @@ if __name__ == "__main__":
         args,
         callbacks=[vis],
         terminate_on_nan=True,
-        weights_summary=None,
-        # accelerator="ddp",
+        weights_summary="full",
+        accelerator="ddp",
     )
 
     # Training loop
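
Note: lowering `variance` from 0.3 to 0.1 sharpens the Gaussian window that RSLVQ-style models place around each prototype. As a hedged illustration of why the value matters, textbook RSLVQ uses p(x|w) proportional to exp(-||x - w||^2 / (2 * variance)); whether ProtoTorch's implementation matches this exactly is not shown in the diff:

import torch


def gaussian_similarity(x, protos, variance=0.1):
    """Textbook RSLVQ-style Gaussian kernel, for illustration only."""
    d2 = torch.cdist(x, protos).pow(2)      # squared distances to prototypes
    return torch.exp(-d2 / (2 * variance))  # smaller variance -> sharper peaks


x = torch.tensor([[5.1, 1.4]])
protos = torch.tensor([[5.0, 1.5], [6.5, 5.0]])
print(gaussian_similarity(x, protos, variance=0.3))
print(gaussian_similarity(x, protos, variance=0.1))  # more peaked around the nearby prototype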