prototorch_models/prototorch/models/architectures/base.py

291 lines
8.8 KiB
Python
Raw Normal View History

2022-05-17 14:25:43 +00:00
"""
Proto Y Architecture

Network architecture for component-based learning.
"""
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import Any, Callable
2022-05-17 14:25:43 +00:00
import pytorch_lightning as pl
import torch
2022-05-18 13:43:09 +00:00
from torchmetrics import Metric
2022-05-17 14:25:43 +00:00
class Steps:
    """
    String constants naming the Lightning stages.

    This is a plain namespace rather than an ``enum.Enum``. The values stay
    ordinary ``str`` objects, so they can be used directly as dictionary keys
    and compared with plain strings. (The previous base class, the builtin
    iterator ``enumerate``, was a mistake.)
    """

    TRAINING = "training"
    VALIDATION = "validation"
    TEST = "test"
    PREDICT = "predict"
class BaseYArchitecture(pl.LightningModule):
    """
    Base class for Y-shaped (component based) architectures.

    ``forward`` runs: components layer -> backbone -> comparison -> inference.
    ``loss_forward`` runs the same pipeline but ends in ``loss`` instead of
    ``inference``. Subclasses customize the pipeline by overriding the
    ``init_*`` hooks and the step methods below.

    Metrics registered through ``register_torchmetric`` are updated on every
    training/validation/test step. Their values are handed to the registered
    callbacks at the end of each epoch.
    """

    @dataclass
    class HyperParameters:
        """
        Add all hyperparameters in the inherited class.
        """

        ...

    # Fields. Both are created per instance in ``__init__`` so that separate
    # models (and subclasses) never share metric state or callbacks.
    registered_metrics: dict[str, dict[type[Metric], Metric]]
    registered_metric_callbacks: dict[str, dict[type[Metric], set[Callable]]]

    # Type Hints for Necessary Fields
    components_layer: torch.nn.Module

    def __init__(self, hparams) -> None:
        """
        Save the hyperparameters and run every ``init_*`` hook in order.

        :param hparams: either a plain ``dict`` (for example the
            hyperparameters restored from a checkpoint) or an instance of
            ``self.HyperParameters``.
        """
        if isinstance(hparams, dict):
            self.save_hyperparameters(hparams)
            # TODO: => Move into Component Child
            # Work on a copy so the caller's dict is not mutated, and accept
            # dicts that never carried the key.
            hparams = dict(hparams)
            hparams.pop("initialized_proto_shape", None)
            hparams = self.HyperParameters(**hparams)
        else:
            hparams_dict = asdict(hparams)
            # Not stored in the saved hyperparameters; presumably because the
            # initializer object is not serializable -- TODO confirm.
            hparams_dict["component_initializer"] = None
            self.save_hyperparameters(hparams_dict)

        super().__init__()

        # Per-instance metric registries, one entry per step that supports
        # metrics (prediction does not).
        metric_steps = (Steps.TRAINING, Steps.VALIDATION, Steps.TEST)
        self.registered_metrics = {step: {} for step in metric_steps}
        self.registered_metric_callbacks = {step: {} for step in metric_steps}

        # Common Steps
        self.init_components(hparams)
        self.init_backbone(hparams)
        self.init_comparison(hparams)
        self.init_competition(hparams)
        # Train Steps
        self.init_loss(hparams)
        # Inference Steps
        self.init_inference(hparams)

    # external API
    def get_competition(self, batch, components):
        """
        Run the backbone and comparison steps and return their result.

        NOTE: despite its name, this returns the *comparison* tensor. The
        ``competition`` step is not invoked here.
        """
        latent_batch, latent_components = self.backbone(batch, components)
        # TODO: => Latent Hook
        comparison_tensor = self.comparison(latent_batch, latent_components)
        # TODO: => Comparison Hook
        return comparison_tensor

    def forward(self, batch):
        """
        Return the prediction produced by ``inference`` for ``batch``.

        A bare tensor is wrapped as ``(batch, None)`` before it is passed on.
        """
        if isinstance(batch, torch.Tensor):
            batch = (batch, None)
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        comparison_tensor = self.get_competition(batch, components)
        # TODO: => Competition Hook
        return self.inference(comparison_tensor, components)

    def predict(self, batch):
        """
        Alias for forward
        """
        return self.forward(batch)

    def forward_comparison(self, batch):
        """
        Return the output of the comparison step for ``batch``.

        A bare tensor is wrapped as ``(batch, None)``, as in ``forward``.
        """
        if isinstance(batch, torch.Tensor):
            batch = (batch, None)
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        return self.get_competition(batch, components)

    def loss_forward(self, batch):
        """
        Return the output of the loss step for ``batch``.

        Unlike ``forward``, ``batch`` is used unchanged (no tensor wrapping).
        It is also passed to ``loss`` together with the comparison tensor.
        """
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        comparison_tensor = self.get_competition(batch, components)
        # TODO: => Competition Hook
        return self.loss(comparison_tensor, batch, components)

    # Empty Initialization
    def init_components(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the components step.
        """

    def init_backbone(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the backbone step.
        """

    def init_comparison(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the comparison step.
        """

    def init_competition(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the competition step.
        """

    def init_loss(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the loss step.
        """

    def init_inference(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the inference step.
        """

    # Empty Steps
    def components(self):
        """
        This step has no input.
        It returns the components.
        """
        raise NotImplementedError(
            "The components step has no reasonable default.")

    def backbone(self, batch, components):
        """
        The backbone step receives the data batch and the components.

        It can transform both by an arbitrary function.
        It returns the transformed batch and components,
        each of the same length as the original input.
        The default returns both unchanged.
        """
        return batch, components

    def comparison(self, batch, components):
        """
        Takes a batch of size N and the component set of size M.

        It returns an NxMxD tensor containing D (usually 1) pairwise
        comparison measures.
        """
        raise NotImplementedError(
            "The comparison step has no reasonable default.")

    def competition(self, comparison_measures, components):
        """
        Takes the tensor of comparison measures.
        Assigns a competition vector to each class.
        """
        raise NotImplementedError(
            "The competition step has no reasonable default.")

    def loss(self, comparison_measures, batch, components):
        """
        Takes the tensor of comparison measures.
        Calculates a single loss value.
        """
        raise NotImplementedError("The loss step has no reasonable default.")

    def inference(self, comparison_measures, components):
        """
        Takes the tensor of comparison measures.
        Returns the inferred vector.
        """
        raise NotImplementedError(
            "The inference step has no reasonable default.")

    # Y Architecture Hooks

    # internal API, called by models and callbacks
    def register_torchmetric(
        self,
        name: Callable,
        metric: type[Metric],
        step: str = Steps.TRAINING,
        **metric_kwargs,
    ):
        """
        Register a callback for evaluating a torchmetric.

        The first registration of a metric type for a step instantiates it
        with ``metric_kwargs``. Later registrations of the same type for that
        step only add another callback; their ``metric_kwargs`` are ignored.

        :param name: callback, invoked as ``name(value, model)`` with the
            metric value computed at the end of each epoch.
        :param metric: torchmetrics ``Metric`` subclass (not an instance).
        :param step: one of ``Steps.TRAINING``, ``Steps.VALIDATION`` or
            ``Steps.TEST``.
        :raises ValueError: if ``step`` is ``Steps.PREDICT``.
        """
        if step == Steps.PREDICT:
            raise ValueError("Prediction metrics are not supported.")

        step_metrics = self.registered_metrics[step]
        step_callbacks = self.registered_metric_callbacks[step]
        # Look the metric up in this step's registry, not in the outer
        # step-keyed dict; otherwise earlier callbacks would be discarded.
        if metric not in step_metrics:
            step_metrics[metric] = metric(**metric_kwargs)
            step_callbacks[metric] = {name}
        else:
            step_callbacks[metric].add(name)

    def update_metrics_step(self, batch, step):
        """
        Update every metric registered for ``step`` with one batch.

        Expects ``batch`` to unpack as ``(data, targets)``. The predictions
        are reshaped to the shape of ``targets``.
        """
        step_metrics = self.registered_metrics[step]
        if not step_metrics:
            # Nothing to update; skip the extra forward pass.
            return

        # Prediction Metrics
        preds = self(batch)
        _, y = batch
        preds = preds.reshape(y.shape)
        for instance in step_metrics.values():
            instance = instance.to(self.device)
            # torchmetrics expects (preds, target), in that order.
            instance(preds, y)

    def update_metrics_epoch(self, step):
        """
        Compute each metric registered for ``step``, pass the value to its
        callbacks, then reset the metric for the next epoch.
        """
        step_callbacks = self.registered_metric_callbacks[step]
        for metric, instance in self.registered_metrics[step].items():
            instance = instance.to(self.device)
            value = instance.compute()
            for callback in step_callbacks[metric]:
                callback(value, self)
            instance.reset()

    # Lightning steps
    # -------------------------------------------------------------------------
    # >>>> Training
    def training_step(self, batch, batch_idx, optimizer_idx=None):
        self.update_metrics_step(batch, Steps.TRAINING)
        return self.loss_forward(batch)

    def training_epoch_end(self, outputs) -> None:
        self.update_metrics_epoch(Steps.TRAINING)

    # >>>> Validation
    def validation_step(self, batch, batch_idx):
        self.update_metrics_step(batch, Steps.VALIDATION)
        return self.loss_forward(batch)

    def validation_epoch_end(self, outputs) -> None:
        self.update_metrics_epoch(Steps.VALIDATION)

    # >>>> Test
    def test_step(self, batch, batch_idx):
        self.update_metrics_step(batch, Steps.TEST)
        return self.loss_forward(batch)

    def test_epoch_end(self, outputs) -> None:
        self.update_metrics_epoch(Steps.TEST)

    # >>>> Prediction
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self.predict(batch)

    # Check points
    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Compatible with Lightning: nest the saved hyperparameters under the
        # ``hparams`` keyword so loading calls ``__init__(hparams=...)``.
        checkpoint["hyper_parameters"] = {
            'hparams': checkpoint["hyper_parameters"]
        }
        return super().on_save_checkpoint(checkpoint)