feat: metric callback defaults on epoch

This commit is contained in:
Alexander Engelsberger 2022-08-26 10:58:33 +02:00
parent c3cad19853
commit 16ca409f07
No known key found for this signature in database
GPG Key ID: DE8669706B6AC2E7
2 changed files with 33 additions and 17 deletions

View File

@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
if type(hparams) is dict:
if isinstance(hparams, dict):
self.save_hyperparameters(hparams)
# TODO: => Move into Component Child
del hparams["initialized_proto_shape"]
hparams = self.HyperParameters(**hparams)
else:
hparam_dict = asdict(hparams)
hparam_dict["component_initializer"] = None
self.save_hyperparameters(hparam_dict, )
hparams_dict = asdict(hparams)
hparams_dict["component_initializer"] = None
self.save_hyperparameters(hparams_dict, )
super().__init__()
@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):
# external API
def get_competition(self, batch, components):
'''
Returns the output of the competition layer.
'''
latent_batch, latent_components = self.backbone(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
return comparison_tensor
def forward(self, batch):
'''
Returns the prediction.
'''
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
return self.forward(batch)
def forward_comparison(self, batch):
'''
Returns the Output of the comparison layer.
'''
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
return self.get_competition(batch, components)
def loss_forward(self, batch):
'''
Returns the output of the loss layer.
'''
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
"""
All initialization necessary for the components step.
"""
...
def init_backbone(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the backbone step.
"""
...
def init_comparison(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the comparison step.
"""
...
def init_competition(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the competition step.
"""
...
def init_loss(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the loss step.
"""
...
def init_inference(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the inference step.
"""
...
# Empty Steps
def components(self):
@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
The backbone step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
It returns the transformed batch and components,
each of the same length as the original input.
"""
return batch, components
@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
step: str = Steps.TRAINING,
**metric_kwargs,
):
'''
Register a callback for evaluating a torchmetric.
'''
if step == Steps.PREDICT:
raise ValueError("Prediction metrics are not supported.")
@ -224,7 +234,7 @@ class BaseYArchitecture(pl.LightningModule):
# Prediction Metrics
preds = self(batch)
x, y = batch
_, y = batch
for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds)
@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
def training_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.TRAINING)
# >>>> Validation
@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss_forward(batch)
def validation_epoch_end(self, outs) -> None:
def validation_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.VALIDATION)
# >>>> Test
@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
self.update_metrics_step(batch, Steps.TEST)
return self.loss_forward(batch)
def test_epoch_end(self, outs) -> None:
def test_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.TEST)
# >>>> Prediction

View File

@ -1,4 +1,3 @@
import logging
import warnings
from typing import Optional, Type
@ -36,12 +35,14 @@ class LogTorchmetricCallback(pl.Callback):
name,
metric: Type[torchmetrics.Metric],
step: str = Steps.TRAINING,
on_epoch=True,
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.step = step
self.on_epoch = on_epoch
def setup(
self,
@ -57,7 +58,12 @@ class LogTorchmetricCallback(pl.Callback):
)
def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value)
pl_module.log(
self.name,
value,
on_epoch=self.on_epoch,
on_step=(not self.on_epoch),
)
class LogConfusionMatrix(LogTorchmetricCallback):