feat: metric callback defaults on epoch

Alexander Engelsberger 2022-08-26 10:58:33 +02:00
parent c3cad19853
commit 16ca409f07
2 changed files with 33 additions and 17 deletions


@@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):
     components_layer: torch.nn.Module

     def __init__(self, hparams) -> None:
-        if type(hparams) is dict:
+        if isinstance(hparams, dict):
             self.save_hyperparameters(hparams)
             # TODO: => Move into Component Child
             del hparams["initialized_proto_shape"]
             hparams = self.HyperParameters(**hparams)
         else:
-            hparam_dict = asdict(hparams)
-            hparam_dict["component_initializer"] = None
-            self.save_hyperparameters(hparam_dict, )
+            hparams_dict = asdict(hparams)
+            hparams_dict["component_initializer"] = None
+            self.save_hyperparameters(hparams_dict, )

         super().__init__()
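The hunk above swaps the exact-type check for `isinstance`, which also accepts dict subclasses, and renames `hparam_dict` to `hparams_dict` for consistency. A minimal standalone sketch of the same normalization logic, with illustrative field names (the real class carries more fields and also calls `save_hyperparameters`):

import torch
from dataclasses import asdict, dataclass

@dataclass
class HyperParameters:
    lr: float = 0.01
    component_initializer: object = None

def normalize_hparams(hparams):
    if isinstance(hparams, dict):
        # Dicts (e.g. restored from a checkpoint) are promoted to the dataclass.
        return HyperParameters(**hparams)
    # Dataclass instances are flattened for logging; non-serializable fields
    # such as the component initializer are dropped first.
    hparams_dict = asdict(hparams)
    hparams_dict["component_initializer"] = None
    return hparams_dict

print(normalize_hparams({"lr": 0.05}))       # HyperParameters(lr=0.05, ...)
print(normalize_hparams(HyperParameters()))  # plain dict, initializer cleared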
@@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):
     # external API
     def get_competition(self, batch, components):
+        '''
+        Returns the output of the competition layer.
+        '''
         latent_batch, latent_components = self.backbone(batch, components)
         # TODO: => Latent Hook
         comparison_tensor = self.comparison(latent_batch, latent_components)
@@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
         return comparison_tensor

     def forward(self, batch):
+        '''
+        Returns the prediction.
+        '''
         if isinstance(batch, torch.Tensor):
             batch = (batch, None)
         # TODO: manage different datatypes?
@@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
         return self.forward(batch)

     def forward_comparison(self, batch):
+        '''
+        Returns the output of the comparison layer.
+        '''
         if isinstance(batch, torch.Tensor):
             batch = (batch, None)
         # TODO: manage different datatypes?
@@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
         return self.get_competition(batch, components)

     def loss_forward(self, batch):
+        '''
+        Returns the output of the loss layer.
+        '''
         # TODO: manage different datatypes?
         components = self.components_layer()
         # TODO: => Component Hook
@@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
         """
         All initialization necessary for the components step.
         """
-        ...

     def init_backbone(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the backbone step.
         """
-        ...

     def init_comparison(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the comparison step.
         """
-        ...

     def init_competition(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the competition step.
         """
-        ...

     def init_loss(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the loss step.
         """
-        ...

     def init_inference(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the inference step.
         """
-        ...

     # Empty Steps
     def components(self):
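The removed `...` placeholders were redundant: a docstring already forms a complete function body. Subclasses are expected to override these hooks; a hedged sketch of one such override, where the class name and the cross-entropy loss are purely illustrative, not the library's actual default:

import torch

class MyArchitecture:  # stands in for a BaseYArchitecture subclass
    def init_loss(self, hparams) -> None:
        # Build whatever the loss step needs from the hyperparameters.
        self.loss_fn = torch.nn.CrossEntropyLoss()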
@@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
         The backbone step receives the data batch and the components.
         It can transform both by an arbitrary function.
-        It returns the transformed batch and components, each of the same length as the original input.
+        It returns the transformed batch and components,
+        each of the same length as the original input.
         """
         return batch, components
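A minimal sketch of the backbone contract documented above: both inputs pass through the same map and come back with their original lengths. The shared linear projection and all shapes here are illustrative.

import torch

def backbone(batch, components, projection):
    # One shared transform so batch and components end up in the same
    # latent space for the later comparison step.
    return batch @ projection, components @ projection

batch = torch.randn(8, 4)        # 8 samples, 4 features
components = torch.randn(3, 4)   # 3 prototypes, 4 features
projection = torch.randn(4, 2)   # shared 4 -> 2 projection

latent_batch, latent_components = backbone(batch, components, projection)
assert len(latent_batch) == len(batch)
assert len(latent_components) == len(components)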
@@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
         step: str = Steps.TRAINING,
         **metric_kwargs,
     ):
+        '''
+        Register a callback for evaluating a torchmetric.
+        '''
         if step == Steps.PREDICT:
             raise ValueError("Prediction metrics are not supported.")
@@ -224,7 +234,7 @@ class BaseYArchitecture(pl.LightningModule):
         # Prediction Metrics
         preds = self(batch)
-        x, y = batch
+        _, y = batch

         for metric in self.registered_metrics[step]:
             instance = self.registered_metrics[step][metric].to(self.device)
             instance(y, preds)
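The loop above follows the usual torchmetrics pattern: calling the stateful metric instance once per batch accumulates state, and nothing is aggregated until the epoch ends (note that torchmetrics documents the argument order as `(preds, target)`). A self-contained sketch; depending on the torchmetrics version, the `task` argument shown here may or may not be required:

import torch
import torchmetrics

# Stateful metric; calling it per batch updates internal counters.
accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=3)

for _ in range(4):  # four simulated batches
    preds = torch.randint(0, 3, (16,))
    target = torch.randint(0, 3, (16,))
    accuracy(preds, target)  # returns the batch value, accumulates state

print(accuracy.compute())  # aggregate over all four batches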
@@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):
         return self.loss_forward(batch)

-    def training_epoch_end(self, outs) -> None:
+    def training_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.TRAINING)

     # >>>> Validation
@@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):
         return self.loss_forward(batch)

-    def validation_epoch_end(self, outs) -> None:
+    def validation_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.VALIDATION)

     # >>>> Test
@@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
         self.update_metrics_step(batch, Steps.TEST)
         return self.loss_forward(batch)

-    def test_epoch_end(self, outs) -> None:
+    def test_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.TEST)

     # >>>> Prediction
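Each *_epoch_end hook delegates to update_metrics_epoch, whose body is not part of this diff. Typically such a helper computes the aggregate, logs it, and resets the metric so the next epoch starts clean; a hedged sketch of that shape, with illustrative names:

def update_metrics_epoch(metrics: dict, log) -> None:
    # metrics maps names to stateful torchmetrics instances; log is a
    # pl_module.log-style callable. Both are illustrative stand-ins.
    for name, metric in metrics.items():
        log(name, metric.compute())  # aggregate over the whole epoch
        metric.reset()               # clear accumulated state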


@@ -1,4 +1,3 @@
-import logging
 import warnings

 from typing import Optional, Type
@@ -36,12 +35,14 @@ class LogTorchmetricCallback(pl.Callback):
         name,
         metric: Type[torchmetrics.Metric],
         step: str = Steps.TRAINING,
+        on_epoch=True,
         **metric_kwargs,
     ) -> None:
         self.name = name
         self.metric = metric
         self.metric_kwargs = metric_kwargs
         self.step = step
+        self.on_epoch = on_epoch

     def setup(
         self,
@@ -57,7 +58,12 @@ class LogTorchmetricCallback(pl.Callback):
         )

     def __call__(self, value, pl_module: BaseYArchitecture):
-        pl_module.log(self.name, value)
+        pl_module.log(
+            self.name,
+            value,
+            on_epoch=self.on_epoch,
+            on_step=(not self.on_epoch),
+        )


 class LogConfusionMatrix(LogTorchmetricCallback):
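A hedged usage sketch of the new default, assuming `LogTorchmetricCallback` and `Steps` are importable from the modules shown in this diff; the metric class and its keyword arguments are illustrative and are forwarded via `**metric_kwargs`:

import torchmetrics

# on_epoch defaults to True, so the callback logs once per epoch and passes
# on_step=False to pl_module.log; set on_epoch=False for per-step logging.
val_acc = LogTorchmetricCallback(
    "val_acc",
    torchmetrics.Accuracy,
    step=Steps.VALIDATION,
    task="multiclass",  # forwarded to the metric constructor
    num_classes=3,
)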