feat: metric callback defaults on epoch
parent c3cad19853 · commit 16ca409f07
@@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):

     components_layer: torch.nn.Module

     def __init__(self, hparams) -> None:
-        if type(hparams) is dict:
+        if isinstance(hparams, dict):
             self.save_hyperparameters(hparams)
             # TODO: => Move into Component Child
             del hparams["initialized_proto_shape"]
             hparams = self.HyperParameters(**hparams)
         else:
-            hparam_dict = asdict(hparams)
-            hparam_dict["component_initializer"] = None
-            self.save_hyperparameters(hparam_dict, )
+            hparams_dict = asdict(hparams)
+            hparams_dict["component_initializer"] = None
+            self.save_hyperparameters(hparams_dict, )

         super().__init__()
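
Note: the switch from "type(hparams) is dict" to "isinstance(hparams, dict)" also
accepts dict subclasses, which the exact-type check rejects. A minimal sketch of
the difference (plain Python, values are illustrative):

    from collections import OrderedDict

    hparams = OrderedDict(lr=0.01)
    print(type(hparams) is dict)       # False: exact-type check rejects subclasses
    print(isinstance(hparams, dict))   # True: subclasses still count as dicts
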
@@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):

     # external API
     def get_competition(self, batch, components):
+        '''
+        Returns the output of the competition layer.
+        '''
         latent_batch, latent_components = self.backbone(batch, components)
         # TODO: => Latent Hook
         comparison_tensor = self.comparison(latent_batch, latent_components)
@@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
         return comparison_tensor

     def forward(self, batch):
+        '''
+        Returns the prediction.
+        '''
         if isinstance(batch, torch.Tensor):
             batch = (batch, None)
         # TODO: manage different datatypes?
@@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
         return self.forward(batch)

     def forward_comparison(self, batch):
+        '''
+        Returns the Output of the comparison layer.
+        '''
         if isinstance(batch, torch.Tensor):
             batch = (batch, None)
         # TODO: manage different datatypes?
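
Note: forward() and forward_comparison() both normalize a bare tensor into a
(batch, None) tuple, so callers can pass either form. A stand-alone sketch of
that check (normalize_batch is a hypothetical helper mirroring the code above):

    import torch

    def normalize_batch(batch):
        # mirrors the isinstance check in forward()/forward_comparison()
        if isinstance(batch, torch.Tensor):
            batch = (batch, None)
        return batch

    x = torch.randn(8, 2)
    print(normalize_batch(x))            # bare tensor is wrapped to (x, None)
    print(normalize_batch((x, None)))    # explicit tuple passes through unchanged
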
@@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
         return self.get_competition(batch, components)

     def loss_forward(self, batch):
+        '''
+        Returns the output of the loss layer.
+        '''
         # TODO: manage different datatypes?
         components = self.components_layer()
         # TODO: => Component Hook
@@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
         """
         All initialization necessary for the components step.
         """
         ...

     def init_backbone(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the backbone step.
         """
         ...

     def init_comparison(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the comparison step.
         """
         ...

     def init_competition(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the competition step.
         """
         ...

     def init_loss(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the loss step.
         """
         ...

     def init_inference(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the inference step.
         """
         ...

     # Empty Steps
     def components(self):
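
Note: the init_* hooks are deliberately empty ("..."), so subclasses override
only the steps they need. A minimal sketch, assuming the first hook (whose
signature the hunk cuts off) is named init_components; the subclass and layer
are illustrative:

    class MyArchitecture(BaseYArchitecture):
        def init_components(self, hparams) -> None:
            # override only this step; the remaining init_* hooks stay empty
            self.components_layer = torch.nn.Linear(2, 3)
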
@@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
         The backbone step receives the data batch and the components.
         It can transform both by an arbitrary function.

-        It returns the transformed batch and components, each of the same length as the original input.
+        It returns the transformed batch and components,
+        each of the same length as the original input.
         """
         return batch, components
@@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
         step: str = Steps.TRAINING,
         **metric_kwargs,
     ):
+        '''
+        Register a callback for evaluating a torchmetric.
+        '''
         if step == Steps.PREDICT:
             raise ValueError("Prediction metrics are not supported.")
@@ -224,7 +234,7 @@ class BaseYArchitecture(pl.LightningModule):
         # Prediction Metrics
         preds = self(batch)

-        x, y = batch
+        _, y = batch
         for metric in self.registered_metrics[step]:
             instance = self.registered_metrics[step][metric].to(self.device)
             instance(y, preds)
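
Note: the rename x -> _ signals that the inputs are unused here; only the
targets feed the metric update. Registered metrics follow the usual
torchmetrics pattern of accumulating state on every step and computing once
per epoch. A stand-alone sketch (assumes a torchmetrics version whose Accuracy
takes a task argument):

    import torch
    import torchmetrics

    acc = torchmetrics.Accuracy(task="multiclass", num_classes=3)

    for _ in range(10):                       # one epoch of batches
        preds = torch.randint(0, 3, (16,))
        target = torch.randint(0, 3, (16,))
        acc(preds, target)                    # update state per step

    print(acc.compute())                      # aggregate over the epoch
    acc.reset()                               # clear state for the next epoch
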
@@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):

         return self.loss_forward(batch)

-    def training_epoch_end(self, outs) -> None:
+    def training_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.TRAINING)

     # >>>> Validation
@@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):

         return self.loss_forward(batch)

-    def validation_epoch_end(self, outs) -> None:
+    def validation_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.VALIDATION)

     # >>>> Test
@@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
         self.update_metrics_step(batch, Steps.TEST)
         return self.loss_forward(batch)

-    def test_epoch_end(self, outs) -> None:
+    def test_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.TEST)

     # >>>> Prediction
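
Note: the rename outs -> outputs matches the parameter name of the
corresponding hooks on pl.LightningModule (the *_epoch_end hooks exist in
Lightning 1.x; Lightning 2.0 removed them in favor of on_*_epoch_end). Shape
of the hook, as a sketch under that assumption:

    import pytorch_lightning as pl

    class MyModule(pl.LightningModule):
        def training_epoch_end(self, outputs) -> None:
            # outputs: list of the values returned by training_step this epoch
            ...
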
@@ -1,4 +1,3 @@
-import logging
 import warnings
 from typing import Optional, Type
@@ -36,12 +35,14 @@ class LogTorchmetricCallback(pl.Callback):
         name,
         metric: Type[torchmetrics.Metric],
         step: str = Steps.TRAINING,
+        on_epoch=True,
         **metric_kwargs,
     ) -> None:
         self.name = name
         self.metric = metric
         self.metric_kwargs = metric_kwargs
         self.step = step
+        self.on_epoch = on_epoch

     def setup(
         self,
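
Note: with the new default, a callback constructed without on_epoch logs epoch
aggregates. A usage sketch; the metric choice, log name, and keyword arguments
are illustrative (extras are forwarded to the metric via **metric_kwargs, and
which kwargs Accuracy accepts depends on the torchmetrics version):

    import torchmetrics

    acc_cb = LogTorchmetricCallback(
        "train_acc",
        torchmetrics.Accuracy,
        step=Steps.TRAINING,   # default step
        on_epoch=True,         # new default: one logged value per epoch
        num_classes=3,         # forwarded to torchmetrics.Accuracy
    )
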
@@ -57,7 +58,12 @@ class LogTorchmetricCallback(pl.Callback):
         )

     def __call__(self, value, pl_module: BaseYArchitecture):
-        pl_module.log(self.name, value)
+        pl_module.log(
+            self.name,
+            value,
+            on_epoch=self.on_epoch,
+            on_step=(not self.on_epoch),
+        )
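
Note: on_epoch and on_step are set as exact complements, so each call logs at
exactly one granularity. In Lightning, self.log(..., on_epoch=True,
on_step=False) accumulates the value and emits a single epoch-level entry,
while on_step=True emits it on every step. The two resulting calls, spelled
out as a sketch (the log name is illustrative):

    pl_module.log("train_acc", value, on_epoch=True, on_step=False)   # epoch aggregate
    pl_module.log("train_acc", value, on_epoch=False, on_step=True)   # per-step value
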


 class LogConfusionMatrix(LogTorchmetricCallback):