feat: metrics can be assigned to the different phases
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import logging
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torchmetrics
|
||||
from prototorch.core import SMCI
|
||||
from prototorch.y.architectures.base import Steps
|
||||
from prototorch.y.callbacks import (
|
||||
LogTorchmetricCallback,
|
||||
PlotLambdaMatrixToTensorboard,
|
||||
@@ -9,7 +12,9 @@ from prototorch.y.callbacks import (
|
||||
)
|
||||
from prototorch.y.library.gmlvq import GMLVQ
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# ##############################################################################
|
||||
|
||||
@@ -20,22 +25,42 @@ def main():
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Dataset
|
||||
train_ds = pt.datasets.Iris()
|
||||
full_dataset = pt.datasets.Iris()
|
||||
full_count = len(full_dataset)
|
||||
|
||||
train_count = int(full_count * 0.5)
|
||||
val_count = int(full_count * 0.4)
|
||||
test_count = int(full_count * 0.1)
|
||||
|
||||
train_dataset, val_dataset, test_dataset = random_split(
|
||||
full_dataset, (train_count, val_count, test_count))
|
||||
|
||||
# Dataloader
|
||||
train_loader = DataLoader(
|
||||
train_ds,
|
||||
batch_size=32,
|
||||
num_workers=0,
|
||||
train_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
shuffle=False,
|
||||
)
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
num_workers=0,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# HYPERPARAMETERS
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Select Initializer
|
||||
components_initializer = SMCI(train_ds)
|
||||
components_initializer = SMCI(full_dataset)
|
||||
|
||||
# Define Hyperparameters
|
||||
hyperparameters = GMLVQ.HyperParameters(
|
||||
@@ -51,17 +76,23 @@ def main():
|
||||
# Create Model
|
||||
model = GMLVQ(hyperparameters)
|
||||
|
||||
print(model.hparams)
|
||||
|
||||
# ------------------------------------------------------------
|
||||
# TRAINING
|
||||
# ------------------------------------------------------------
|
||||
|
||||
# Controlling Callbacks
|
||||
stopping_criterion = LogTorchmetricCallback(
|
||||
'recall',
|
||||
recall = LogTorchmetricCallback(
|
||||
'training_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.TRAINING,
|
||||
)
|
||||
|
||||
stopping_criterion = LogTorchmetricCallback(
|
||||
'validation_recall',
|
||||
torchmetrics.Recall,
|
||||
num_classes=3,
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
es = EarlyStopping(
|
||||
@@ -71,18 +102,23 @@ def main():
|
||||
)
|
||||
|
||||
# Visualization Callback
|
||||
vis = VisGMLVQ2D(data=train_ds)
|
||||
vis = VisGMLVQ2D(data=full_dataset)
|
||||
|
||||
# Define trainer
|
||||
trainer = pl.Trainer(callbacks=[
|
||||
vis,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
], )
|
||||
trainer = pl.Trainer(
|
||||
callbacks=[
|
||||
vis,
|
||||
recall,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
],
|
||||
max_epochs=100,
|
||||
)
|
||||
|
||||
# Train
|
||||
trainer.fit(model, train_loader)
|
||||
trainer.fit(model, train_loader, val_loader)
|
||||
trainer.test(model, test_loader)
|
||||
|
||||
# Manual save
|
||||
trainer.save_checkpoint("./y_arch.ckpt")
|
||||
@@ -93,8 +129,6 @@ def main():
|
||||
strict=True,
|
||||
)
|
||||
|
||||
print(new_model.hparams)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Reference in New Issue
Block a user