2022-05-18 12:11:46 +00:00
|
|
|
import prototorch as pt
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torchmetrics
|
|
|
|
from prototorch.core import SMCI
|
2022-06-03 08:39:11 +00:00
|
|
|
from prototorch.y.callbacks import (
|
2022-05-18 12:11:46 +00:00
|
|
|
LogTorchmetricCallback,
|
2022-05-19 14:13:08 +00:00
|
|
|
PlotLambdaMatrixToTensorboard,
|
|
|
|
VisGMLVQ2D,
|
2022-05-18 12:11:46 +00:00
|
|
|
)
|
2022-06-03 08:39:11 +00:00
|
|
|
from prototorch.y.library.gmlvq import GMLVQ
|
2022-05-18 12:11:46 +00:00
|
|
|
from pytorch_lightning.callbacks import EarlyStopping
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
# ##############################################################################
|
|
|
|
|
|
|
|
|
2022-06-12 08:36:15 +00:00
|
|
|
def main():
    """Train, checkpoint, and reload a GMLVQ model on the Iris dataset.

    Steps performed:

    1. Wrap ``pt.datasets.Iris()`` in a shuffled ``DataLoader`` (batch size
       32, no worker processes).
    2. Configure ``GMLVQ`` with an ``SMCI`` component initializer, 4 input
       dimensions and one component per class for 3 classes, then print the
       model's ``hparams``.
    3. Fit with a Lightning ``Trainer`` whose callbacks log a recall metric,
       stop early when that metric stops improving (``mode="max"``,
       patience 10), and produce visualizations (``VisGMLVQ2D``,
       ``PlotLambdaMatrixToTensorboard``).
    4. Save a checkpoint to ``./y_arch.ckpt``, reload it strictly through
       ``GMLVQ.load_from_checkpoint`` and print the reloaded ``hparams``.
    """
    # -- Data -----------------------------------------------------------------
    dataset = pt.datasets.Iris()
    loader = DataLoader(
        dataset=dataset,
        batch_size=32,
        num_workers=0,
        shuffle=True,
    )

    # -- Model ----------------------------------------------------------------
    # Shared by the component distribution and the recall metric below.
    num_classes = 3

    initializer = SMCI(dataset)
    learning_rates = {
        "components_layer": 0.1,
        # NOTE(review): a rate of 0 presumably keeps omega fixed during
        # training -- confirm this is intended for the example.
        "_omega": 0,
    }
    component_distribution = {"num_classes": num_classes, "per_class": 1}
    hparams = GMLVQ.HyperParameters(
        lr=learning_rates,
        input_dim=4,
        distribution=component_distribution,
        component_initializer=initializer,
    )
    model = GMLVQ(hparams)
    print(model.hparams)

    # -- Training -------------------------------------------------------------
    # Logs the metric under the name 'recall'; early stopping monitors it.
    recall_logger = LogTorchmetricCallback(
        'recall',
        torchmetrics.Recall,
        num_classes=num_classes,
    )
    early_stopping = EarlyStopping(
        monitor=recall_logger.name,
        mode="max",
        patience=10,
    )
    visualizer = VisGMLVQ2D(data=dataset)

    callbacks = [
        visualizer,
        recall_logger,
        early_stopping,
        PlotLambdaMatrixToTensorboard(),
    ]
    trainer = pl.Trainer(callbacks=callbacks)
    trainer.fit(model, loader)

    # -- Checkpoint round-trip ------------------------------------------------
    checkpoint_file = "./y_arch.ckpt"
    trainer.save_checkpoint(checkpoint_file)

    restored = GMLVQ.load_from_checkpoint(
        checkpoint_path=checkpoint_file,
        strict=True,
    )
    print(restored.hparams)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the example only when this file is executed directly, not on import.
    main()
|