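# prototorch_models/examples/y_architecture_example.py
#
# Example: train a GMLVQ model from prototorch's y_arch library on the Iris
# dataset, with recall-based early stopping and 2D / TensorBoard visualization.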

import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.models.y_arch.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.models.y_arch.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
# ##############################################################################
if __name__ == "__main__":
    # ------------------------------------------------------------
    # DATA
    # ------------------------------------------------------------

    # Dataset
    train_ds = pt.datasets.Iris()

    # Problem size for Iris, defined once so that the prototype
    # distribution and the recall metric below cannot drift apart.
    input_dim = 4
    num_classes = 3

    # Dataloader
    train_loader = DataLoader(
        train_ds,
        batch_size=32,
        num_workers=0,
        shuffle=True,
    )

    # ------------------------------------------------------------
    # HYPERPARAMETERS
    # ------------------------------------------------------------

    # Select Initializer: SMCI (prototorch's stratified-mean components
    # initializer), built from the training data.
    components_initializer = SMCI(train_ds)

    # Define Hyperparameters
    hyperparameters = GMLVQ.HyperParameters(
        lr=0.1,
        # Presumably the learning rate for the GMLVQ backbone (Omega)
        # parameters; TODO confirm against GMLVQ.HyperParameters.
        backbone_lr=5,
        input_dim=input_dim,
        # Presumably one prototype per class; verify in GMLVQ docs.
        distribution=dict(
            num_classes=num_classes,
            per_class=1,
        ),
        component_initializer=components_initializer,
    )

    # Create Model
    model = GMLVQ(hyperparameters)
    print(model)

    # ------------------------------------------------------------
    # TRAINING
    # ------------------------------------------------------------

    # Controlling Callbacks
    # Logs a torchmetrics Recall under the name 'recall'; EarlyStopping
    # monitors that logged value by the callback's name.
    stopping_criterion = LogTorchmetricCallback(
        'recall',
        torchmetrics.Recall,
        num_classes=num_classes,
    )
    # Higher recall is better, hence mode="max".
    es = EarlyStopping(
        monitor=stopping_criterion.name,
        mode="max",
        patience=10,
    )

    # Visualization Callback
    vis = VisGMLVQ2D(data=train_ds)

    # Define trainer
    # NOTE(review): stopping_criterion is listed before es, presumably so
    # the metric is logged before EarlyStopping reads it; keep this order.
    trainer = pl.Trainer(
        callbacks=[
            vis,
            stopping_criterion,
            es,
            PlotLambdaMatrixToTensorboard(),
        ],
        max_epochs=1000,
    )

    # Train
    trainer.fit(model, train_loader)