2022-05-18 12:11:46 +00:00
|
|
|
import prototorch as pt
|
|
|
|
import pytorch_lightning as pl
|
|
|
|
import torchmetrics
|
|
|
|
from prototorch.core import SMCI
|
2022-05-19 14:13:08 +00:00
|
|
|
from prototorch.models.y_arch.callbacks import (
|
2022-05-18 12:11:46 +00:00
|
|
|
LogTorchmetricCallback,
|
2022-05-19 14:13:08 +00:00
|
|
|
PlotLambdaMatrixToTensorboard,
|
|
|
|
VisGMLVQ2D,
|
2022-05-18 12:11:46 +00:00
|
|
|
)
|
2022-05-19 14:13:08 +00:00
|
|
|
from prototorch.models.y_arch.library.gmlvq import GMLVQ
|
2022-05-18 12:11:46 +00:00
|
|
|
from pytorch_lightning.callbacks import EarlyStopping
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
|
|
# ##############################################################################
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # ------------------------------------------------------------
    # DATA
    # ------------------------------------------------------------

    # Dataset
    train_ds = pt.datasets.Iris()

    # Shape of the Iris problem: 4 input features, 3 classes.
    # Both numbers are consumed by the model hyperparameters AND by the
    # recall metric below; keeping them in one place prevents the two
    # from silently drifting apart.
    NUM_FEATURES = 4
    NUM_CLASSES = 3

    # Dataloader
    train_loader = DataLoader(
        train_ds,
        batch_size=32,
        num_workers=0,
        shuffle=True,
    )

    # ------------------------------------------------------------
    # HYPERPARAMETERS
    # ------------------------------------------------------------

    # Select Initializer: prototorch SMCI built from the training data
    # (presumably a stratified-mean initializer — confirm in prototorch.core).
    components_initializer = SMCI(train_ds)

    # Define Hyperparameters
    hyperparameters = GMLVQ.HyperParameters(
        lr=0.1,
        # Separate learning rate for the model's backbone; presumably the
        # GMLVQ relevance/Omega mapping — TODO confirm in GMLVQ.HyperParameters.
        backbone_lr=5,
        input_dim=NUM_FEATURES,
        distribution=dict(
            num_classes=NUM_CLASSES,
            # Presumably one prototype per class — verify against prototorch.
            per_class=1,
        ),
        component_initializer=components_initializer,
    )

    # Create Model
    model = GMLVQ(hyperparameters)

    print(model)

    # ------------------------------------------------------------
    # TRAINING
    # ------------------------------------------------------------

    # Controlling Callbacks
    # Logs a torchmetrics Recall under the name 'recall'; the same name is
    # what EarlyStopping monitors below.
    stopping_criterion = LogTorchmetricCallback(
        'recall',
        torchmetrics.Recall,
        num_classes=NUM_CLASSES,
    )

    # Stop once the logged recall has not improved (higher is better) for
    # 10 consecutive checks.
    es = EarlyStopping(
        monitor=stopping_criterion.name,
        mode="max",
        patience=10,
    )

    # Visualization Callback
    vis = VisGMLVQ2D(data=train_ds)

    # Define trainer
    trainer = pl.Trainer(
        callbacks=[
            vis,
            stopping_criterion,
            es,
            PlotLambdaMatrixToTensorboard(),
        ],
        # Upper bound only; EarlyStopping is expected to end training sooner.
        max_epochs=1000,
    )

    # Train
    trainer.fit(model, train_loader)
|