feat: metrics can be assigned to the different phases

This commit is contained in:
Alexander Engelsberger
2022-06-24 15:04:35 +02:00
parent 46ec7b07d7
commit 736565b768
6 changed files with 237 additions and 69 deletions

View File

@@ -1,7 +1,10 @@
import logging
import prototorch as pt
import pytorch_lightning as pl
import torchmetrics
from prototorch.core import SMCI
from prototorch.y.architectures.base import Steps
from prototorch.y.callbacks import (
LogTorchmetricCallback,
PlotLambdaMatrixToTensorboard,
@@ -9,7 +12,9 @@ from prototorch.y.callbacks import (
)
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
logging.basicConfig(level=logging.INFO)
# ##############################################################################
@@ -20,22 +25,42 @@ def main():
# ------------------------------------------------------------
# Dataset
train_ds = pt.datasets.Iris()
full_dataset = pt.datasets.Iris()
full_count = len(full_dataset)
train_count = int(full_count * 0.5)
val_count = int(full_count * 0.4)
test_count = int(full_count * 0.1)
train_dataset, val_dataset, test_dataset = random_split(
full_dataset, (train_count, val_count, test_count))
# Dataloader
train_loader = DataLoader(
train_ds,
batch_size=32,
num_workers=0,
train_dataset,
batch_size=1,
num_workers=4,
shuffle=True,
)
val_loader = DataLoader(
val_dataset,
batch_size=1,
num_workers=4,
shuffle=False,
)
test_loader = DataLoader(
test_dataset,
batch_size=1,
num_workers=0,
shuffle=False,
)
# ------------------------------------------------------------
# HYPERPARAMETERS
# ------------------------------------------------------------
# Select Initializer
components_initializer = SMCI(train_ds)
components_initializer = SMCI(full_dataset)
# Define Hyperparameters
hyperparameters = GMLVQ.HyperParameters(
@@ -51,17 +76,23 @@ def main():
# Create Model
model = GMLVQ(hyperparameters)
print(model.hparams)
# ------------------------------------------------------------
# TRAINING
# ------------------------------------------------------------
# Controlling Callbacks
stopping_criterion = LogTorchmetricCallback(
'recall',
recall = LogTorchmetricCallback(
'training_recall',
torchmetrics.Recall,
num_classes=3,
step=Steps.TRAINING,
)
stopping_criterion = LogTorchmetricCallback(
'validation_recall',
torchmetrics.Recall,
num_classes=3,
step=Steps.VALIDATION,
)
es = EarlyStopping(
@@ -71,18 +102,23 @@ def main():
)
# Visualization Callback
vis = VisGMLVQ2D(data=train_ds)
vis = VisGMLVQ2D(data=full_dataset)
# Define trainer
trainer = pl.Trainer(callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
], )
trainer = pl.Trainer(
callbacks=[
vis,
recall,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
],
max_epochs=100,
)
# Train
trainer.fit(model, train_loader)
trainer.fit(model, train_loader, val_loader)
trainer.test(model, test_loader)
# Manual save
trainer.save_checkpoint("./y_arch.ckpt")
@@ -93,8 +129,6 @@ def main():
strict=True,
)
print(new_model.hparams)
if __name__ == "__main__":
main()