refactor(api)!: merge the new api changes into dev
This commit is contained in:
@@ -2,11 +2,10 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -24,14 +23,18 @@ if __name__ == "__main__":
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[2, 2, 2],
|
||||
proto_lr=0.1,
|
||||
distribution=[1, 0, 3],
|
||||
margin=0.1,
|
||||
proto_lr=0.01,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.CBC(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.SSI(train_ds, noise=0.01),
|
||||
components_initializer=pt.initializers.SSCI(train_ds, noise=0.01),
|
||||
reasonings_iniitializer=pt.initializers.
|
||||
PurePositiveReasoningsInitializer(),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
|
@@ -2,11 +2,10 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -37,7 +36,7 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = pt.models.CELVQ(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.Ones(2, scale=3),
|
||||
prototypes_initializer=pt.initializers.FVCI(2, 3.0),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
|
@@ -2,12 +2,11 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -24,7 +23,7 @@ if __name__ == "__main__":
|
||||
hparams = dict(
|
||||
distribution={
|
||||
"num_classes": 3,
|
||||
"prototypes_per_class": 4
|
||||
"per_class": 4
|
||||
},
|
||||
lr=0.01,
|
||||
)
|
||||
@@ -33,7 +32,7 @@ if __name__ == "__main__":
|
||||
model = pt.models.GLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
@@ -2,11 +2,10 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -26,7 +25,6 @@ if __name__ == "__main__":
|
||||
distribution=(num_classes, prototypes_per_class),
|
||||
transfer_function="swish_beta",
|
||||
transfer_beta=10.0,
|
||||
# lr=0.1,
|
||||
proto_lr=0.1,
|
||||
bb_lr=0.1,
|
||||
input_dim=2,
|
||||
@@ -37,7 +35,7 @@ if __name__ == "__main__":
|
||||
model = pt.models.GMLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.SSI(train_ds, noise=1e-2),
|
||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-2),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
@@ -47,12 +45,12 @@ if __name__ == "__main__":
|
||||
block=False,
|
||||
)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.02,
|
||||
threshold=0.01,
|
||||
idle_epochs=10,
|
||||
prune_quota_per_epoch=5,
|
||||
frequency=2,
|
||||
frequency=5,
|
||||
replace=True,
|
||||
initializer=pt.components.SSI(train_ds, noise=1e-2),
|
||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=1e-1),
|
||||
verbose=True,
|
||||
)
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
@@ -68,7 +66,7 @@ if __name__ == "__main__":
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
# es,
|
||||
# es, # FIXME
|
||||
pruning,
|
||||
],
|
||||
terminate_on_nan=True,
|
||||
|
@@ -1,59 +0,0 @@
|
||||
"""GLVQ example using the Iris dataset."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris()
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
input_dim=4,
|
||||
latent_dim=3,
|
||||
distribution={
|
||||
"num_classes": 3,
|
||||
"prototypes_per_class": 2
|
||||
},
|
||||
proto_lr=0.0005,
|
||||
bb_lr=0.0005,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GMLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.SSI(train_ds),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
omega_initializer=pt.components.PCA(train_ds.data)
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
#model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGMLVQ2D(data=train_ds, border=0.1)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
@@ -2,11 +2,10 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
@@ -30,7 +29,7 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = pt.models.GrowingNeuralGas(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.Zeros(2),
|
||||
prototypes_initializer=pt.initializers.ZCI(2),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
|
@@ -2,25 +2,11 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
def hex_to_rgb(hex_values):
|
||||
for v in hex_values:
|
||||
v = v.lstrip('#')
|
||||
lv = len(v)
|
||||
c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
|
||||
yield c
|
||||
|
||||
|
||||
def rgb_to_hex(rgb_values):
|
||||
for v in rgb_values:
|
||||
c = "%02x%02x%02x" % tuple(v)
|
||||
yield c
|
||||
from prototorch.utils.colors import hex_to_rgb
|
||||
|
||||
|
||||
class Vis2DColorSOM(pl.Callback):
|
||||
@@ -93,7 +79,7 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = pt.models.KohonenSOM(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.Random(3),
|
||||
prototypes_initializer=pt.initializers.RNCI(3),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
|
@@ -2,23 +2,22 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||
|
||||
# Reproducibility
|
||||
pl.utilities.seed.seed_everything(seed=2)
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Moons(num_samples=300, noise=0.2, seed=42)
|
||||
|
||||
# Dataloaders
|
||||
train_loader = torch.utils.data.DataLoader(train_ds,
|
||||
batch_size=256,
|
||||
@@ -32,8 +31,10 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.LGMLVQ(hparams,
|
||||
prototype_initializer=pt.components.SMI(train_ds))
|
||||
model = pt.models.LGMLVQ(
|
||||
hparams,
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
@@ -3,11 +3,10 @@
|
||||
import argparse
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
def plot_matrix(matrix):
|
||||
title = "Lambda matrix"
|
||||
@@ -40,20 +39,19 @@ if __name__ == "__main__":
|
||||
hparams = dict(
|
||||
distribution={
|
||||
"num_classes": 2,
|
||||
"prototypes_per_class": 1
|
||||
"per_class": 1,
|
||||
},
|
||||
input_dim=100,
|
||||
latent_dim=2,
|
||||
proto_lr=0.0001,
|
||||
bb_lr=0.0001,
|
||||
proto_lr=0.001,
|
||||
bb_lr=0.001,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.SiameseGMLVQ(
|
||||
hparams,
|
||||
# optimizer=torch.optim.SGD,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
)
|
||||
|
||||
# Summary
|
||||
|
@@ -2,11 +2,10 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
@@ -41,7 +40,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[1, 2, 2],
|
||||
distribution=[3, 4, 5],
|
||||
proto_lr=0.001,
|
||||
bb_lr=0.001,
|
||||
)
|
||||
@@ -52,7 +51,10 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = pt.models.LVQMLN(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.SSI(train_ds, transform=backbone),
|
||||
prototypes_initializer=pt.initializers.SSCI(
|
||||
train_ds,
|
||||
transform=backbone,
|
||||
),
|
||||
backbone=backbone,
|
||||
)
|
||||
|
||||
@@ -67,11 +69,21 @@ if __name__ == "__main__":
|
||||
resolution=500,
|
||||
axis_off=True,
|
||||
)
|
||||
pruning = pt.models.PruneLoserPrototypes(
|
||||
threshold=0.01,
|
||||
idle_epochs=20,
|
||||
prune_quota_per_epoch=2,
|
||||
frequency=10,
|
||||
verbose=True,
|
||||
)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
callbacks=[
|
||||
vis,
|
||||
pruning,
|
||||
],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
|
@@ -2,11 +2,9 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchvision.transforms import Lambda
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
@@ -28,19 +26,17 @@ if __name__ == "__main__":
|
||||
distribution=[2, 2, 3],
|
||||
proto_lr=0.05,
|
||||
lambd=0.1,
|
||||
variance=1.0,
|
||||
input_dim=2,
|
||||
latent_dim=2,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.probabilistic.PLVQ(
|
||||
model = pt.models.RSLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
# prototype_initializer=pt.components.SMI(train_ds),
|
||||
prototype_initializer=pt.components.SSI(train_ds, noise=0.2),
|
||||
# prototype_initializer=pt.components.Zeros(2),
|
||||
# prototype_initializer=pt.components.Ones(2, scale=2.0),
|
||||
prototypes_initializer=pt.initializers.SSCI(train_ds, noise=0.2),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
@@ -50,7 +46,7 @@ if __name__ == "__main__":
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisSiameseGLVQ2D(data=train_ds)
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
|
@@ -2,11 +2,10 @@
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import prototorch as pt
|
||||
|
||||
|
||||
class Backbone(torch.nn.Module):
|
||||
def __init__(self, input_size=4, hidden_size=10, latent_size=2):
|
||||
@@ -52,7 +51,7 @@ if __name__ == "__main__":
|
||||
# Initialize the model
|
||||
model = pt.models.SiameseGLVQ(
|
||||
hparams,
|
||||
prototype_initializer=pt.components.SMI(train_ds),
|
||||
prototypes_initializer=pt.initializers.SMCI(train_ds),
|
||||
backbone=backbone,
|
||||
both_path_gradients=False,
|
||||
)
|
||||
|
84
examples/warm_starting.py
Normal file
84
examples/warm_starting.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Warm-starting GLVQ with prototypes from Growing Neural Gas."""
|
||||
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import ExponentialLR
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare the data
|
||||
train_ds = pt.datasets.Iris(dims=[0, 2])
|
||||
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64)
|
||||
|
||||
# Initialize the gng
|
||||
gng = pt.models.GrowingNeuralGas(
|
||||
hparams=dict(num_prototypes=5, insert_freq=2, lr=0.1),
|
||||
prototypes_initializer=pt.initializers.ZCI(2),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
||||
# Callbacks
|
||||
es = pl.callbacks.EarlyStopping(
|
||||
monitor="loss",
|
||||
min_delta=0.001,
|
||||
patience=20,
|
||||
mode="min",
|
||||
verbose=False,
|
||||
check_on_train_epoch_end=True,
|
||||
)
|
||||
|
||||
# Setup trainer for GNG
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=200,
|
||||
callbacks=[es],
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(gng, train_loader)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[],
|
||||
lr=0.01,
|
||||
)
|
||||
|
||||
# Warm-start prototypes
|
||||
knn = pt.models.KNN(dict(k=1), data=train_ds)
|
||||
prototypes = gng.prototypes
|
||||
plabels = knn.predict(prototypes)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.GLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
prototypes_initializer=pt.initializers.LCI(prototypes),
|
||||
labels_initializer=pt.initializers.LLI(plabels),
|
||||
lr_scheduler=ExponentialLR,
|
||||
lr_scheduler_kwargs=dict(gamma=0.99, verbose=False),
|
||||
)
|
||||
|
||||
# Compute intermediate input and output sizes
|
||||
model.example_input_array = torch.zeros(4, 2)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
accelerator="ddp",
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
Reference in New Issue
Block a user