prototorch_models/examples/gmlvq_iris.py

135 lines
3.2 KiB
Python
Raw Normal View History

2022-08-15 10:14:14 +00:00
import logging
2021-09-01 08:49:57 +00:00
import pytorch_lightning as pl
2022-08-15 10:14:14 +00:00
import torchmetrics
from prototorch.core import SMCI
from prototorch.datasets import Iris
from prototorch.models.architectures.base import Steps
from prototorch.models.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
VisGMLVQ2D,
)
from prototorch.models.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader, random_split
2022-08-15 10:14:14 +00:00
# Configure the root logger to emit records at INFO level and above
# (the level is given by name; logging resolves "INFO" to logging.INFO).
logging.basicConfig(level="INFO")


# ##############################################################################
def main():
    """Train, test, checkpoint and reload a GMLVQ model on the Iris dataset.

    Steps:
      1. Randomly split Iris 50/40/10 into train/validation/test subsets.
      2. Initialize one prototype per class from the training subset (SMCI).
      3. Train with early stopping on the validation recall (mode ``max``,
         patience 10, at most 100 epochs), logging training recall, a 2D
         GMLVQ visualization and the lambda matrix to TensorBoard.
      4. Evaluate on the test subset.
      5. Save a checkpoint and reload it with ``strict=True`` to check that
         the saved state round-trips into a fresh ``GMLVQ``.
    """
    # ------------------------------------------------------------
    # DATA
    # ------------------------------------------------------------

    # Dataset
    full_dataset = Iris()
    full_count = len(full_dataset)
    train_count = int(full_count * 0.5)
    val_count = int(full_count * 0.4)
    # Give the remainder to the test split instead of truncating a third
    # fraction: random_split requires integer lengths that sum exactly to
    # len(full_dataset), which independent int() truncations do not guarantee.
    test_count = full_count - train_count - val_count
    train_dataset, val_dataset, test_dataset = random_split(
        full_dataset, (train_count, val_count, test_count))

    # Dataloader
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=4,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=4,
        shuffle=False,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        num_workers=0,
        shuffle=False,
    )

    # ------------------------------------------------------------
    # HYPERPARAMETERS
    # ------------------------------------------------------------

    # Shared problem dimensions, used by both the model and the metrics.
    num_classes = 3
    input_dim = 4

    # Select Initializer.
    # Use only the training subset so prototype initialization does not
    # see validation/test labels (data leakage).
    components_initializer = SMCI(train_dataset)

    # Define Hyperparameters
    hyperparameters = GMLVQ.HyperParameters(
        # NOTE(review): a learning rate of 0 for `_omega` presumably keeps the
        # relevance matrix fixed at its initialization; confirm intended.
        lr=dict(components_layer=0.1, _omega=0),
        input_dim=input_dim,
        distribution=dict(
            num_classes=num_classes,
            per_class=1,
        ),
        component_initializer=components_initializer,
    )

    # Create Model
    model = GMLVQ(hyperparameters)

    # ------------------------------------------------------------
    # TRAINING
    # ------------------------------------------------------------

    # Controlling Callbacks
    recall = LogTorchmetricCallback(
        'training_recall',
        torchmetrics.Recall,
        num_classes=num_classes,
        step=Steps.TRAINING,
    )

    stopping_criterion = LogTorchmetricCallback(
        'validation_recall',
        torchmetrics.Recall,
        num_classes=num_classes,
        step=Steps.VALIDATION,
    )

    # Stop once the logged validation recall has not improved for 10 checks.
    es = EarlyStopping(
        monitor=stopping_criterion.name,
        mode="max",
        patience=10,
    )

    # Visualization Callback
    vis = VisGMLVQ2D(data=full_dataset)

    # Define trainer
    trainer = pl.Trainer(
        callbacks=[
            vis,
            recall,
            stopping_criterion,
            es,
            PlotLambdaMatrixToTensorboard(),
        ],
        max_epochs=100,
    )

    # Train
    trainer.fit(model, train_loader, val_loader)
    trainer.test(model, test_loader)

    # Manual save
    checkpoint_path = "./y_arch.ckpt"
    trainer.save_checkpoint(checkpoint_path)

    # Load saved model; strict=True makes loading fail if the checkpoint's
    # state does not match the model exactly, so this verifies the round-trip.
    GMLVQ.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=True,
    )
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()