Update example scripts
This commit is contained in:
parent
016fcb4060
commit
5d2a8226ce
@ -30,7 +30,7 @@ if __name__ == "__main__":
|
|||||||
prototypes_per_class = num_clusters * 5
|
prototypes_per_class = num_clusters * 5
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=(num_classes, prototypes_per_class),
|
distribution=(num_classes, prototypes_per_class),
|
||||||
lr=0.1,
|
lr=0.2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
@ -39,6 +39,12 @@ if __name__ == "__main__":
|
|||||||
prototype_initializer=pt.components.Ones(2, scale=3),
|
prototype_initializer=pt.components.Ones(2, scale=3),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(train_ds)
|
vis = pt.models.VisGLVQ2D(train_ds)
|
||||||
pruning = pt.models.PruneLoserPrototypes(
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
@ -67,7 +73,7 @@ if __name__ == "__main__":
|
|||||||
],
|
],
|
||||||
progress_bar_refresh_rate=0,
|
progress_bar_refresh_rate=0,
|
||||||
terminate_on_nan=True,
|
terminate_on_nan=True,
|
||||||
weights_summary=None,
|
weights_summary="full",
|
||||||
accelerator="ddp",
|
accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2,11 +2,10 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -14,35 +13,63 @@ if __name__ == "__main__":
|
|||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Dataset
|
# Dataset
|
||||||
train_ds = pt.datasets.Spiral(num_samples=600, noise=0.6)
|
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
|
||||||
|
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
|
||||||
num_workers=0,
|
|
||||||
batch_size=256)
|
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
num_classes = 2
|
num_classes = 2
|
||||||
prototypes_per_class = 20
|
prototypes_per_class = 10
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=(num_classes, prototypes_per_class),
|
distribution=(num_classes, prototypes_per_class),
|
||||||
transfer_function="sigmoid_beta",
|
transfer_function="swish_beta",
|
||||||
transfer_beta=10.0,
|
transfer_beta=10.0,
|
||||||
lr=0.01,
|
# lr=0.1,
|
||||||
|
proto_lr=0.1,
|
||||||
|
bb_lr=0.1,
|
||||||
|
input_dim=2,
|
||||||
|
latent_dim=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.GLVQ(hparams,
|
model = pt.models.GMLVQ(
|
||||||
prototype_initializer=pt.components.SSI(train_ds,
|
hparams,
|
||||||
noise=1e-1))
|
optimizer=torch.optim.Adam,
|
||||||
|
prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
|
||||||
|
)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
vis = pt.models.VisGLVQ2D(
|
||||||
|
train_ds,
|
||||||
|
show_last_only=False,
|
||||||
|
block=False,
|
||||||
|
)
|
||||||
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
|
threshold=0.02,
|
||||||
|
idle_epochs=10,
|
||||||
|
prune_quota_per_epoch=5,
|
||||||
|
frequency=2,
|
||||||
|
replace=True,
|
||||||
|
initializer=pt.components.SSI(train_ds, noise=1e-2),
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_loss",
|
||||||
|
min_delta=1.0,
|
||||||
|
patience=5,
|
||||||
|
mode="min",
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
# es,
|
||||||
|
pruning,
|
||||||
|
],
|
||||||
terminate_on_nan=True,
|
terminate_on_nan=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -43,7 +43,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
num_classes = 10
|
num_classes = 10
|
||||||
prototypes_per_class = 2
|
prototypes_per_class = 10
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
input_dim=28 * 28,
|
input_dim=28 * 28,
|
||||||
latent_dim=28 * 28,
|
latent_dim=28 * 28,
|
||||||
@ -62,19 +62,40 @@ if __name__ == "__main__":
|
|||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisImgComp(
|
vis = pt.models.VisImgComp(
|
||||||
data=train_ds,
|
data=train_ds,
|
||||||
num_columns=5,
|
num_columns=10,
|
||||||
show=False,
|
show=False,
|
||||||
tensorboard=True,
|
tensorboard=True,
|
||||||
random_data=20,
|
random_data=100,
|
||||||
add_embedding=True,
|
add_embedding=True,
|
||||||
embedding_data=100,
|
embedding_data=200,
|
||||||
flatten_data=False,
|
flatten_data=False,
|
||||||
)
|
)
|
||||||
|
pruning = pt.models.PruneLoserPrototypes(
|
||||||
|
threshold=0.01,
|
||||||
|
idle_epochs=1,
|
||||||
|
prune_quota_per_epoch=10,
|
||||||
|
frequency=1,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
es = pl.callbacks.EarlyStopping(
|
||||||
|
monitor="train_loss",
|
||||||
|
min_delta=0.001,
|
||||||
|
patience=15,
|
||||||
|
mode="min",
|
||||||
|
check_on_train_epoch_end=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[
|
||||||
|
vis,
|
||||||
|
pruning,
|
||||||
|
# es,
|
||||||
|
],
|
||||||
|
terminate_on_nan=True,
|
||||||
|
weights_summary=None,
|
||||||
|
accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -4,10 +4,7 @@ import argparse
|
|||||||
|
|
||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from prototorch.components.initializers import Zeros
|
import torch
|
||||||
from prototorch.datasets import Iris
|
|
||||||
from prototorch.models.unsupervised import GrowingNeuralGas
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
@ -19,8 +16,8 @@ if __name__ == "__main__":
|
|||||||
pl.utilities.seed.seed_everything(seed=42)
|
pl.utilities.seed.seed_everything(seed=42)
|
||||||
|
|
||||||
# Prepare the data
|
# Prepare the data
|
||||||
train_ds = Iris(dims=[0, 2])
|
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||||
train_loader = DataLoader(train_ds, batch_size=8)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
@ -29,11 +26,14 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = GrowingNeuralGas(
|
model = pt.models.GrowingNeuralGas(
|
||||||
hparams,
|
hparams,
|
||||||
prototype_initializer=Zeros(2),
|
prototype_initializer=pt.components.Zeros(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Model summary
|
# Model summary
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
@ -45,6 +45,7 @@ if __name__ == "__main__":
|
|||||||
args,
|
args,
|
||||||
max_epochs=100,
|
max_epochs=100,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
|
weights_summary="full",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,13 +1,12 @@
|
|||||||
"""k-NN example using the Iris dataset."""
|
"""k-NN example using the Iris dataset from scikit-learn."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -23,18 +22,30 @@ if __name__ == "__main__":
|
|||||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
|
||||||
|
|
||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(k=20)
|
hparams = dict(k=5)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.KNN(hparams, data=train_ds)
|
model = pt.models.KNN(hparams, data=train_ds)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train), resolution=200)
|
vis = pt.models.VisGLVQ2D(
|
||||||
|
data=(x_train, y_train),
|
||||||
|
resolution=200,
|
||||||
|
block=True,
|
||||||
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
|
max_epochs=1,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
|
weights_summary="full",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -7,6 +7,7 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.preprocessing import StandardScaler
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from torch.optim.lr_scheduler import ExponentialLR
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
@ -30,8 +31,15 @@ if __name__ == "__main__":
|
|||||||
hparams = dict(num_prototypes=30, lr=0.03)
|
hparams = dict(num_prototypes=30, lr=0.03)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.NeuralGas(hparams,
|
model = pt.models.NeuralGas(
|
||||||
prototype_initializer=pt.components.Zeros(2))
|
hparams,
|
||||||
|
prototype_initializer=pt.components.Zeros(2),
|
||||||
|
lr_scheduler=ExponentialLR,
|
||||||
|
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
# Model summary
|
# Model summary
|
||||||
print(model)
|
print(model)
|
||||||
@ -43,6 +51,7 @@ if __name__ == "__main__":
|
|||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
|
weights_summary="full",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -1,12 +1,11 @@
|
|||||||
"""Probabilistic-LVQ example using the Iris dataset."""
|
"""RSLVQ example using the Iris dataset."""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@ -26,16 +25,23 @@ if __name__ == "__main__":
|
|||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[2, 2, 3],
|
distribution=[2, 2, 3],
|
||||||
lr=0.05,
|
lr=0.05,
|
||||||
variance=0.3,
|
variance=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.probabilistic.RSLVQ(
|
model = pt.models.probabilistic.RSLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
|
# prototype_initializer=pt.components.SMI(train_ds),
|
||||||
prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
|
prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
|
||||||
|
# prototype_initializer=pt.components.Zeros(2),
|
||||||
|
# prototype_initializer=pt.components.Ones(2, scale=2.0),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Compute intermediate input and output sizes
|
||||||
|
model.example_input_array = torch.zeros(4, 2)
|
||||||
|
|
||||||
|
# Summary
|
||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
@ -46,8 +52,8 @@ if __name__ == "__main__":
|
|||||||
args,
|
args,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
terminate_on_nan=True,
|
terminate_on_nan=True,
|
||||||
weights_summary=None,
|
weights_summary="full",
|
||||||
# accelerator="ddp",
|
accelerator="ddp",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
Loading…
Reference in New Issue
Block a user