Update example scripts

commit 5d2a8226ce (parent 016fcb4060)
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     prototypes_per_class = num_clusters * 5
     hparams = dict(
         distribution=(num_classes, prototypes_per_class),
-        lr=0.1,
+        lr=0.2,
     )
 
     # Initialize the model
@@ -39,6 +39,12 @@ if __name__ == "__main__":
         prototype_initializer=pt.components.Ones(2, scale=3),
     )
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Summary
+    print(model)
+
     # Callbacks
     vis = pt.models.VisGLVQ2D(train_ds)
     pruning = pt.models.PruneLoserPrototypes(
@@ -67,7 +73,7 @@ if __name__ == "__main__":
         ],
         progress_bar_refresh_rate=0,
         terminate_on_nan=True,
-        weights_summary=None,
+        weights_summary="full",
         accelerator="ddp",
     )
 
@@ -2,11 +2,10 @@
 
 import argparse
 
+import prototorch as pt
 import pytorch_lightning as pl
 import torch
-
-import prototorch as pt
 
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -14,35 +13,63 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     # Dataset
-    train_ds = pt.datasets.Spiral(num_samples=600, noise=0.6)
+    train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
 
     # Dataloaders
-    train_loader = torch.utils.data.DataLoader(train_ds,
-                                               num_workers=0,
-                                               batch_size=256)
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)
 
     # Hyperparameters
     num_classes = 2
-    prototypes_per_class = 20
+    prototypes_per_class = 10
     hparams = dict(
         distribution=(num_classes, prototypes_per_class),
-        transfer_function="sigmoid_beta",
+        transfer_function="swish_beta",
         transfer_beta=10.0,
-        lr=0.01,
+        # lr=0.1,
+        proto_lr=0.1,
+        bb_lr=0.1,
+        input_dim=2,
+        latent_dim=2,
     )
 
     # Initialize the model
-    model = pt.models.GLVQ(hparams,
-                           prototype_initializer=pt.components.SSI(train_ds,
-                                                                   noise=1e-1))
+    model = pt.models.GMLVQ(
+        hparams,
+        optimizer=torch.optim.Adam,
+        prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
+    )
 
     # Callbacks
-    vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
+    vis = pt.models.VisGLVQ2D(
+        train_ds,
+        show_last_only=False,
+        block=False,
+    )
+    pruning = pt.models.PruneLoserPrototypes(
+        threshold=0.02,
+        idle_epochs=10,
+        prune_quota_per_epoch=5,
+        frequency=2,
+        replace=True,
+        initializer=pt.components.SSI(train_ds, noise=1e-2),
+        verbose=True,
+    )
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=1.0,
+        patience=5,
+        mode="min",
+        check_on_train_epoch_end=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+            # es,
+            pruning,
+        ],
         terminate_on_nan=True,
     )
 
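Pieced together, the updated Spiral/GMLVQ example boils down to roughly the sketch below. It is illustrative only: the actual script drives the Trainer through argparse via pl.Trainer.from_argparse_args and keeps the VisGLVQ2D callback, and the max_epochs value used here is an assumption.

import prototorch as pt
import pytorch_lightning as pl
import torch

# Dataset and dataloader, as in the updated example
train_ds = pt.datasets.Spiral(num_samples=500, noise=0.5)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256)

# Hyperparameters from the diff (2 classes, 10 prototypes per class)
hparams = dict(
    distribution=(2, 10),
    transfer_function="swish_beta",
    transfer_beta=10.0,
    proto_lr=0.1,
    bb_lr=0.1,
    input_dim=2,
    latent_dim=2,
)

# GMLVQ with an Adam optimizer and SSI prototype initialization plus noise
model = pt.models.GMLVQ(
    hparams,
    optimizer=torch.optim.Adam,
    prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
)

# Prune prototypes that rarely win and replace them with freshly drawn ones
pruning = pt.models.PruneLoserPrototypes(
    threshold=0.02,
    idle_epochs=10,
    prune_quota_per_epoch=5,
    frequency=2,
    replace=True,
    initializer=pt.components.SSI(train_ds, noise=1e-2),
    verbose=True,
)

# max_epochs is assumed; the example script configures the Trainer from argparse
trainer = pl.Trainer(max_epochs=100, callbacks=[pruning], terminate_on_nan=True)
trainer.fit(model, train_loader)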
@@ -43,7 +43,7 @@ if __name__ == "__main__":
 
     # Hyperparameters
     num_classes = 10
-    prototypes_per_class = 2
+    prototypes_per_class = 10
     hparams = dict(
         input_dim=28 * 28,
         latent_dim=28 * 28,
@@ -62,19 +62,40 @@ if __name__ == "__main__":
     # Callbacks
     vis = pt.models.VisImgComp(
         data=train_ds,
-        num_columns=5,
+        num_columns=10,
         show=False,
         tensorboard=True,
-        random_data=20,
+        random_data=100,
         add_embedding=True,
-        embedding_data=100,
+        embedding_data=200,
         flatten_data=False,
     )
+    pruning = pt.models.PruneLoserPrototypes(
+        threshold=0.01,
+        idle_epochs=1,
+        prune_quota_per_epoch=10,
+        frequency=1,
+        verbose=True,
+    )
+    es = pl.callbacks.EarlyStopping(
+        monitor="train_loss",
+        min_delta=0.001,
+        patience=15,
+        mode="min",
+        check_on_train_epoch_end=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
-        callbacks=[vis],
+        callbacks=[
+            vis,
+            pruning,
+            # es,
+        ],
+        terminate_on_nan=True,
+        weights_summary=None,
+        accelerator="ddp",
     )
 
     # Training loop
@@ -4,10 +4,7 @@ import argparse
 
 import prototorch as pt
 import pytorch_lightning as pl
-from prototorch.components.initializers import Zeros
-from prototorch.datasets import Iris
-from prototorch.models.unsupervised import GrowingNeuralGas
-from torch.utils.data import DataLoader
+import torch
 
 if __name__ == "__main__":
     # Command-line arguments
@@ -19,8 +16,8 @@ if __name__ == "__main__":
     pl.utilities.seed.seed_everything(seed=42)
 
     # Prepare the data
-    train_ds = Iris(dims=[0, 2])
-    train_loader = DataLoader(train_ds, batch_size=8)
+    train_ds = pt.datasets.Iris(dims=[0, 2])
+    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
 
     # Hyperparameters
     hparams = dict(
@@ -29,11 +26,14 @@ if __name__ == "__main__":
     )
 
     # Initialize the model
-    model = GrowingNeuralGas(
+    model = pt.models.GrowingNeuralGas(
         hparams,
-        prototype_initializer=Zeros(2),
+        prototype_initializer=pt.components.Zeros(2),
     )
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
     # Model summary
     print(model)
 
@@ -45,6 +45,7 @@ if __name__ == "__main__":
         args,
         max_epochs=100,
         callbacks=[vis],
+        weights_summary="full",
     )
 
     # Training loop
@@ -1,13 +1,12 @@
-"""k-NN example using the Iris dataset."""
+"""k-NN example using the Iris dataset from scikit-learn."""
 
 import argparse
 
+import prototorch as pt
 import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
-
-import prototorch as pt
 
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -23,18 +22,30 @@ if __name__ == "__main__":
     train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)
 
     # Hyperparameters
-    hparams = dict(k=20)
+    hparams = dict(k=5)
 
     # Initialize the model
     model = pt.models.KNN(hparams, data=train_ds)
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
+    # Summary
+    print(model)
+
     # Callbacks
-    vis = pt.models.VisGLVQ2D(data=(x_train, y_train), resolution=200)
+    vis = pt.models.VisGLVQ2D(
+        data=(x_train, y_train),
+        resolution=200,
+        block=True,
+    )
 
     # Setup trainer
     trainer = pl.Trainer.from_argparse_args(
         args,
+        max_epochs=1,
         callbacks=[vis],
+        weights_summary="full",
     )
 
     # Training loop
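For quick reference, the updated k-NN setup can be exercised with a minimal sketch like the one below. The dataset choice (pt.datasets.Iris with two features) and the plain pl.Trainer call are assumptions; the example script itself wraps scikit-learn's load_iris data and uses pl.Trainer.from_argparse_args.

import prototorch as pt
import pytorch_lightning as pl
import torch

# Two Iris features, as in the other Iris examples in this commit (assumed)
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)

# k-NN keeps the training data as its components; k=5 as in the diff
model = pt.models.KNN(dict(k=5), data=train_ds)

# A single epoch suffices, since there is nothing to optimize for k-NN
trainer = pl.Trainer(max_epochs=1, weights_summary="full")
trainer.fit(model, train_loader)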
@@ -7,6 +7,7 @@ import pytorch_lightning as pl
 import torch
 from sklearn.datasets import load_iris
 from sklearn.preprocessing import StandardScaler
+from torch.optim.lr_scheduler import ExponentialLR
 
 if __name__ == "__main__":
     # Command-line arguments
@@ -30,8 +31,15 @@ if __name__ == "__main__":
     hparams = dict(num_prototypes=30, lr=0.03)
 
     # Initialize the model
-    model = pt.models.NeuralGas(hparams,
-                                prototype_initializer=pt.components.Zeros(2))
+    model = pt.models.NeuralGas(
+        hparams,
+        prototype_initializer=pt.components.Zeros(2),
+        lr_scheduler=ExponentialLR,
+        lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
+    )
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
     # Model summary
     print(model)
@@ -43,6 +51,7 @@ if __name__ == "__main__":
     trainer = pl.Trainer.from_argparse_args(
         args,
         callbacks=[vis],
+        weights_summary="full",
     )
 
     # Training loop
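The learning-rate scheduler wiring added above can be tried in isolation with a sketch along these lines. The dataset, batch size, and max_epochs are assumptions; the actual script standardizes scikit-learn's Iris data with StandardScaler before wrapping it in a dataset.

import prototorch as pt
import pytorch_lightning as pl
import torch
from torch.optim.lr_scheduler import ExponentialLR

# Assumed dataset: two Iris features, as elsewhere in this commit
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=150)

# Hyperparameters from the diff
hparams = dict(num_prototypes=30, lr=0.03)

# Neural Gas whose learning rate decays by 1% per scheduler step
model = pt.models.NeuralGas(
    hparams,
    prototype_initializer=pt.components.Zeros(2),
    lr_scheduler=ExponentialLR,
    lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
)

trainer = pl.Trainer(max_epochs=100, weights_summary="full")
trainer.fit(model, train_loader)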
@@ -1,12 +1,11 @@
-"""Probabilistic-LVQ example using the Iris dataset."""
+"""RSLVQ example using the Iris dataset."""
 
 import argparse
 
+import prototorch as pt
 import pytorch_lightning as pl
 import torch
-
-import prototorch as pt
 
 if __name__ == "__main__":
     # Command-line arguments
     parser = argparse.ArgumentParser()
@@ -26,16 +25,23 @@ if __name__ == "__main__":
     hparams = dict(
         distribution=[2, 2, 3],
         lr=0.05,
-        variance=0.3,
+        variance=0.1,
     )
 
     # Initialize the model
     model = pt.models.probabilistic.RSLVQ(
         hparams,
         optimizer=torch.optim.Adam,
+        # prototype_initializer=pt.components.SMI(train_ds),
+        prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
+        # prototype_initializer=pt.components.Zeros(2),
+        # prototype_initializer=pt.components.Ones(2, scale=2.0),
     )
 
+    # Compute intermediate input and output sizes
+    model.example_input_array = torch.zeros(4, 2)
+
     # Summary
     print(model)
 
     # Callbacks
@@ -46,8 +52,8 @@ if __name__ == "__main__":
         args,
         callbacks=[vis],
         terminate_on_nan=True,
-        weights_summary=None,
-        # accelerator="ddp",
+        weights_summary="full",
+        accelerator="ddp",
     )
 
     # Training loop
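Assembled from the pieces above, the updated RSLVQ example reduces to roughly this sketch. The dataset, batch size, and max_epochs are assumptions; the script itself builds the Trainer from argparse with terminate_on_nan, a full weights summary, and DDP.

import prototorch as pt
import pytorch_lightning as pl
import torch

# Assumed dataset: two Iris features, as in the other Iris examples
train_ds = pt.datasets.Iris(dims=[0, 2])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)

# Hyperparameters from the diff: 2, 2 and 3 prototypes for the three classes
hparams = dict(
    distribution=[2, 2, 3],
    lr=0.05,
    variance=0.1,
)

# RSLVQ with Adam and the SSI prototype initializer chosen in the diff
model = pt.models.probabilistic.RSLVQ(
    hparams,
    optimizer=torch.optim.Adam,
    prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
)

trainer = pl.Trainer(max_epochs=50, terminate_on_nan=True)
trainer.fit(model, train_loader)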