feat: metric callback defaults on epoch
This commit is contained in:
parent
c3cad19853
commit
16ca409f07
@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
components_layer: torch.nn.Module
|
components_layer: torch.nn.Module
|
||||||
|
|
||||||
def __init__(self, hparams) -> None:
|
def __init__(self, hparams) -> None:
|
||||||
if type(hparams) is dict:
|
if isinstance(hparams, dict):
|
||||||
self.save_hyperparameters(hparams)
|
self.save_hyperparameters(hparams)
|
||||||
# TODO: => Move into Component Child
|
# TODO: => Move into Component Child
|
||||||
del hparams["initialized_proto_shape"]
|
del hparams["initialized_proto_shape"]
|
||||||
hparams = self.HyperParameters(**hparams)
|
hparams = self.HyperParameters(**hparams)
|
||||||
else:
|
else:
|
||||||
hparam_dict = asdict(hparams)
|
hparams_dict = asdict(hparams)
|
||||||
hparam_dict["component_initializer"] = None
|
hparams_dict["component_initializer"] = None
|
||||||
self.save_hyperparameters(hparam_dict, )
|
self.save_hyperparameters(hparams_dict, )
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
|
|
||||||
# external API
|
# external API
|
||||||
def get_competition(self, batch, components):
|
def get_competition(self, batch, components):
|
||||||
|
'''
|
||||||
|
Returns the output of the competition layer.
|
||||||
|
'''
|
||||||
latent_batch, latent_components = self.backbone(batch, components)
|
latent_batch, latent_components = self.backbone(batch, components)
|
||||||
# TODO: => Latent Hook
|
# TODO: => Latent Hook
|
||||||
comparison_tensor = self.comparison(latent_batch, latent_components)
|
comparison_tensor = self.comparison(latent_batch, latent_components)
|
||||||
@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
return comparison_tensor
|
return comparison_tensor
|
||||||
|
|
||||||
def forward(self, batch):
|
def forward(self, batch):
|
||||||
|
'''
|
||||||
|
Returns the prediction.
|
||||||
|
'''
|
||||||
if isinstance(batch, torch.Tensor):
|
if isinstance(batch, torch.Tensor):
|
||||||
batch = (batch, None)
|
batch = (batch, None)
|
||||||
# TODO: manage different datatypes?
|
# TODO: manage different datatypes?
|
||||||
@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
return self.forward(batch)
|
return self.forward(batch)
|
||||||
|
|
||||||
def forward_comparison(self, batch):
|
def forward_comparison(self, batch):
|
||||||
|
'''
|
||||||
|
Returns the Output of the comparison layer.
|
||||||
|
'''
|
||||||
if isinstance(batch, torch.Tensor):
|
if isinstance(batch, torch.Tensor):
|
||||||
batch = (batch, None)
|
batch = (batch, None)
|
||||||
# TODO: manage different datatypes?
|
# TODO: manage different datatypes?
|
||||||
@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
return self.get_competition(batch, components)
|
return self.get_competition(batch, components)
|
||||||
|
|
||||||
def loss_forward(self, batch):
|
def loss_forward(self, batch):
|
||||||
|
'''
|
||||||
|
Returns the output of the loss layer.
|
||||||
|
'''
|
||||||
# TODO: manage different datatypes?
|
# TODO: manage different datatypes?
|
||||||
components = self.components_layer()
|
components = self.components_layer()
|
||||||
# TODO: => Component Hook
|
# TODO: => Component Hook
|
||||||
@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
"""
|
"""
|
||||||
All initialization necessary for the components step.
|
All initialization necessary for the components step.
|
||||||
"""
|
"""
|
||||||
...
|
|
||||||
|
|
||||||
def init_backbone(self, hparams: HyperParameters) -> None:
|
def init_backbone(self, hparams: HyperParameters) -> None:
|
||||||
"""
|
"""
|
||||||
All initialization necessary for the backbone step.
|
All initialization necessary for the backbone step.
|
||||||
"""
|
"""
|
||||||
...
|
|
||||||
|
|
||||||
def init_comparison(self, hparams: HyperParameters) -> None:
|
def init_comparison(self, hparams: HyperParameters) -> None:
|
||||||
"""
|
"""
|
||||||
All initialization necessary for the comparison step.
|
All initialization necessary for the comparison step.
|
||||||
"""
|
"""
|
||||||
...
|
|
||||||
|
|
||||||
def init_competition(self, hparams: HyperParameters) -> None:
|
def init_competition(self, hparams: HyperParameters) -> None:
|
||||||
"""
|
"""
|
||||||
All initialization necessary for the competition step.
|
All initialization necessary for the competition step.
|
||||||
"""
|
"""
|
||||||
...
|
|
||||||
|
|
||||||
def init_loss(self, hparams: HyperParameters) -> None:
|
def init_loss(self, hparams: HyperParameters) -> None:
|
||||||
"""
|
"""
|
||||||
All initialization necessary for the loss step.
|
All initialization necessary for the loss step.
|
||||||
"""
|
"""
|
||||||
...
|
|
||||||
|
|
||||||
def init_inference(self, hparams: HyperParameters) -> None:
|
def init_inference(self, hparams: HyperParameters) -> None:
|
||||||
"""
|
"""
|
||||||
All initialization necessary for the inference step.
|
All initialization necessary for the inference step.
|
||||||
"""
|
"""
|
||||||
...
|
|
||||||
|
|
||||||
# Empty Steps
|
# Empty Steps
|
||||||
def components(self):
|
def components(self):
|
||||||
@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
The backbone step receives the data batch and the components.
|
The backbone step receives the data batch and the components.
|
||||||
It can transform both by an arbitrary function.
|
It can transform both by an arbitrary function.
|
||||||
|
|
||||||
It returns the transformed batch and components, each of the same length as the original input.
|
It returns the transformed batch and components,
|
||||||
|
each of the same length as the original input.
|
||||||
"""
|
"""
|
||||||
return batch, components
|
return batch, components
|
||||||
|
|
||||||
@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
step: str = Steps.TRAINING,
|
step: str = Steps.TRAINING,
|
||||||
**metric_kwargs,
|
**metric_kwargs,
|
||||||
):
|
):
|
||||||
|
'''
|
||||||
|
Register a callback for evaluating a torchmetric.
|
||||||
|
'''
|
||||||
if step == Steps.PREDICT:
|
if step == Steps.PREDICT:
|
||||||
raise ValueError("Prediction metrics are not supported.")
|
raise ValueError("Prediction metrics are not supported.")
|
||||||
|
|
||||||
@ -224,7 +234,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
# Prediction Metrics
|
# Prediction Metrics
|
||||||
preds = self(batch)
|
preds = self(batch)
|
||||||
|
|
||||||
x, y = batch
|
_, y = batch
|
||||||
for metric in self.registered_metrics[step]:
|
for metric in self.registered_metrics[step]:
|
||||||
instance = self.registered_metrics[step][metric].to(self.device)
|
instance = self.registered_metrics[step][metric].to(self.device)
|
||||||
instance(y, preds)
|
instance(y, preds)
|
||||||
@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
|
|
||||||
return self.loss_forward(batch)
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
def training_epoch_end(self, outs) -> None:
|
def training_epoch_end(self, outputs) -> None:
|
||||||
self.update_metrics_epoch(Steps.TRAINING)
|
self.update_metrics_epoch(Steps.TRAINING)
|
||||||
|
|
||||||
# >>>> Validation
|
# >>>> Validation
|
||||||
@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
|
|
||||||
return self.loss_forward(batch)
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
def validation_epoch_end(self, outs) -> None:
|
def validation_epoch_end(self, outputs) -> None:
|
||||||
self.update_metrics_epoch(Steps.VALIDATION)
|
self.update_metrics_epoch(Steps.VALIDATION)
|
||||||
|
|
||||||
# >>>> Test
|
# >>>> Test
|
||||||
@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
self.update_metrics_step(batch, Steps.TEST)
|
self.update_metrics_step(batch, Steps.TEST)
|
||||||
return self.loss_forward(batch)
|
return self.loss_forward(batch)
|
||||||
|
|
||||||
def test_epoch_end(self, outs) -> None:
|
def test_epoch_end(self, outputs) -> None:
|
||||||
self.update_metrics_epoch(Steps.TEST)
|
self.update_metrics_epoch(Steps.TEST)
|
||||||
|
|
||||||
# >>>> Prediction
|
# >>>> Prediction
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
@ -36,12 +35,14 @@ class LogTorchmetricCallback(pl.Callback):
|
|||||||
name,
|
name,
|
||||||
metric: Type[torchmetrics.Metric],
|
metric: Type[torchmetrics.Metric],
|
||||||
step: str = Steps.TRAINING,
|
step: str = Steps.TRAINING,
|
||||||
|
on_epoch=True,
|
||||||
**metric_kwargs,
|
**metric_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.metric = metric
|
self.metric = metric
|
||||||
self.metric_kwargs = metric_kwargs
|
self.metric_kwargs = metric_kwargs
|
||||||
self.step = step
|
self.step = step
|
||||||
|
self.on_epoch = on_epoch
|
||||||
|
|
||||||
def setup(
|
def setup(
|
||||||
self,
|
self,
|
||||||
@ -57,7 +58,12 @@ class LogTorchmetricCallback(pl.Callback):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, value, pl_module: BaseYArchitecture):
|
def __call__(self, value, pl_module: BaseYArchitecture):
|
||||||
pl_module.log(self.name, value)
|
pl_module.log(
|
||||||
|
self.name,
|
||||||
|
value,
|
||||||
|
on_epoch=self.on_epoch,
|
||||||
|
on_step=(not self.on_epoch),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LogConfusionMatrix(LogTorchmetricCallback):
|
class LogConfusionMatrix(LogTorchmetricCallback):
|
||||||
|
Loading…
Reference in New Issue
Block a user