feat: metrics can be assigned to the different phases

This commit is contained in:
Alexander Engelsberger
2022-06-24 15:04:35 +02:00
parent 46ec7b07d7
commit 736565b768
6 changed files with 237 additions and 69 deletions

View File

@@ -13,15 +13,34 @@ import torch
from torchmetrics import Metric
class Steps(enumerate):
TRAINING = "training"
VALIDATION = "validation"
TEST = "test"
PREDICT = "predict"
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:
"""
Add all hyperparameters in the inherited class.
"""
...
# Fields
registered_metrics: dict[type[Metric], Metric] = {}
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
registered_metrics: dict[str, dict[type[Metric], Metric]] = {
Steps.TRAINING: {},
Steps.VALIDATION: {},
Steps.TEST: {},
}
registered_metric_callbacks: dict[str, dict[type[Metric],
set[Callable]]] = {
Steps.TRAINING: {},
Steps.VALIDATION: {},
Steps.TEST: {},
}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
@@ -41,7 +60,7 @@ class BaseYArchitecture(pl.LightningModule):
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
self.init_backbone(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
@@ -53,7 +72,7 @@ class BaseYArchitecture(pl.LightningModule):
# external API
def get_competition(self, batch, components):
latent_batch, latent_components = self.latent(batch, components)
latent_batch, latent_components = self.backbone(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
@@ -92,27 +111,43 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
# TODO: Docs
def init_components(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the components step.
"""
...
def init_latent(self, hparams: HyperParameters) -> None:
def init_backbone(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the backbone step.
"""
...
def init_comparison(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the comparison step.
"""
...
def init_competition(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the competition step.
"""
...
def init_loss(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the loss step.
"""
...
def init_inference(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the inference step.
"""
...
# Empty Steps
# TODO: Type hints
def components(self):
"""
This step has no input.
@@ -122,9 +157,9 @@ class BaseYArchitecture(pl.LightningModule):
raise NotImplementedError(
"The components step has no reasonable default.")
def latent(self, batch, components):
def backbone(self, batch, components):
"""
The latent step receives the data batch and the components.
The backbone step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
@@ -173,52 +208,72 @@ class BaseYArchitecture(pl.LightningModule):
self,
name: Callable,
metric: type[Metric],
step: str = Steps.TRAINING,
**metric_kwargs,
):
if metric not in self.registered_metrics:
self.registered_metrics[metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[metric] = {name}
else:
self.registered_metric_callbacks[metric].add(name)
if step == Steps.PREDICT:
raise ValueError("Prediction metrics are not supported.")
def update_metrics_step(self, batch):
if metric not in self.registered_metrics:
self.registered_metrics[step][metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[step][metric] = {name}
else:
self.registered_metric_callbacks[step][metric].add(name)
def update_metrics_step(self, batch, step):
# Prediction Metrics
preds = self(batch)
x, y = batch
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds)
def update_metrics_epoch(self):
for metric in self.registered_metrics:
instance = self.registered_metrics[metric].to(self.device)
def update_metrics_epoch(self, step):
for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device)
value = instance.compute()
for callback in self.registered_metric_callbacks[metric]:
for callback in self.registered_metric_callbacks[step][metric]:
callback(value, self)
instance.reset()
# Lightning Hooks
# Steps
# Lightning steps
# -------------------------------------------------------------------------
# >>>> Training
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step([torch.clone(el) for el in batch])
self.update_metrics_step(batch, Steps.TRAINING)
return self.loss_forward(batch)
def validation_step(self, batch, batch_idx):
return self.loss_forward(batch)
def test_step(self, batch, batch_idx):
return self.loss_forward(batch)
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
self.update_metrics_epoch(Steps.TRAINING)
# >>>> Validation
def validation_step(self, batch, batch_idx):
self.update_metrics_step(batch, Steps.VALIDATION)
return self.loss_forward(batch)
def validation_epoch_end(self, outs) -> None:
self.update_metrics_epoch(Steps.VALIDATION)
# >>>> Test
def test_step(self, batch, batch_idx):
self.update_metrics_step(batch, Steps.TEST)
return self.loss_forward(batch)
def test_epoch_end(self, outs) -> None:
self.update_metrics_epoch(Steps.TEST)
# >>>> Prediction
def predict_step(self, batch, batch_idx, dataloader_idx=0):
return self.predict(batch)
# Check points
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
# Compatible with Lightning
checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"]
}

View File

@@ -1,3 +1,4 @@
import logging
import warnings
from typing import Optional, Type
@@ -8,7 +9,7 @@ import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
from prototorch.y.architectures.base import BaseYArchitecture
from prototorch.y.architectures.base import BaseYArchitecture, Steps
from prototorch.y.library.gmlvq import GMLVQ
from pytorch_lightning.loggers import TensorBoardLogger
@@ -34,13 +35,13 @@ class LogTorchmetricCallback(pl.Callback):
self,
name,
metric: Type[torchmetrics.Metric],
on="prediction",
step: str = Steps.TRAINING,
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.on = on
self.step = step
def setup(
self,
@@ -48,14 +49,12 @@ class LogTorchmetricCallback(pl.Callback):
pl_module: BaseYArchitecture,
stage: Optional[str] = None,
) -> None:
if self.on == "prediction":
pl_module.register_torchmetric(
self,
self.metric,
**self.metric_kwargs,
)
else:
raise ValueError(f"{self.on} is no valid metric hook")
pl_module.register_torchmetric(
self,
self.metric,
step=self.step,
**self.metric_kwargs,
)
def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value)