prototorch_models/prototorch/y/architectures/base.py

281 lines
8.5 KiB
Python
Raw Normal View History

2022-05-17 14:25:43 +00:00
"""
Proto Y Architecture

Network architecture for component-based learning.
"""
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import Any, Callable
2022-05-17 14:25:43 +00:00
import pytorch_lightning as pl
import torch
2022-05-18 13:43:09 +00:00
from torchmetrics import Metric
2022-05-17 14:25:43 +00:00
class Steps:
    """
    Names of the Lightning stages a model can be in.

    The members are plain strings so they can be used directly as
    dictionary keys and compared against string literals.

    The class is only a namespace for constants; it previously derived from
    the builtin ``enumerate`` iterator type, which served no purpose.
    """

    TRAINING = "training"
    VALIDATION = "validation"
    TEST = "test"
    PREDICT = "predict"
class BaseYArchitecture(pl.LightningModule):
    """
    Base class for models built from the steps of the Y architecture.

    A forward pass fetches the components from ``components_layer``, sends
    batch and components through ``backbone`` and ``comparison`` and hands
    the resulting comparison tensor to ``inference`` (prediction) or ``loss``
    (training / validation / test).

    Subclasses provide the behaviour by overriding the ``init_*`` hooks and
    the step methods. A ``competition`` step is declared for subclasses but
    is not called anywhere in this base class.
    """

    @dataclass
    class HyperParameters:
        """
        Add all hyperparameters in the inherited class.
        """

        ...

    # Fields
    # Both registries are keyed by step name first and metric class second.
    # They are created per instance in ``__init__``; sharing one dict on the
    # class would leak registered metrics between unrelated models.
    registered_metrics: dict[str, dict[type[Metric], Metric]]
    registered_metric_callbacks: dict[str, dict[type[Metric], set[Callable]]]

    # Type Hints for Necessary Fields
    components_layer: torch.nn.Module

    def __init__(self, hparams) -> None:
        """
        Store the hyperparameters and run all ``init_*`` hooks.

        Args:
            hparams: Either an instance of ``HyperParameters`` or a ``dict``
                holding its fields. In the dict form an optional
                ``"initialized_proto_shape"`` entry is dropped before the
                ``HyperParameters`` object is built. The passed dict itself
                is not modified.
        """
        # NOTE(review): the hyperparameters are saved before
        # ``super().__init__()`` runs; that order is kept unchanged on purpose.
        if isinstance(hparams, dict):
            self.save_hyperparameters(hparams)
            # TODO: => Move into Component Child
            # Work on a copy so the caller's dict stays untouched, and
            # tolerate dicts that do not carry the extra entry.
            hparams = dict(hparams)
            hparams.pop("initialized_proto_shape", None)
            hparams = self.HyperParameters(**hparams)
        else:
            hparam_dict = asdict(hparams)
            # The initializer is stored as None in the saved hyperparameters
            # (presumably because it cannot be serialized - TODO confirm).
            hparam_dict["component_initializer"] = None
            self.save_hyperparameters(hparam_dict)

        super().__init__()

        # Per-instance metric registries (prediction metrics are unsupported,
        # so there is no entry for Steps.PREDICT).
        self.registered_metrics = {
            Steps.TRAINING: {},
            Steps.VALIDATION: {},
            Steps.TEST: {},
        }
        self.registered_metric_callbacks = {
            Steps.TRAINING: {},
            Steps.VALIDATION: {},
            Steps.TEST: {},
        }

        # Common Steps
        self.init_components(hparams)
        self.init_backbone(hparams)
        self.init_comparison(hparams)
        self.init_competition(hparams)

        # Train Steps
        self.init_loss(hparams)

        # Inference Steps
        self.init_inference(hparams)

    # external API
    def get_competition(self, batch, components):
        """
        Run the backbone and comparison steps.

        Returns the tensor produced by ``comparison`` for the transformed
        batch and components.
        """
        latent_batch, latent_components = self.backbone(batch, components)
        # TODO: => Latent Hook
        comparison_tensor = self.comparison(latent_batch, latent_components)
        # TODO: => Comparison Hook
        return comparison_tensor

    def forward(self, batch):
        """
        Run the full inference path for ``batch``.

        A bare tensor is wrapped into ``(batch, None)`` so that the following
        steps always receive a pair.
        """
        if isinstance(batch, torch.Tensor):
            batch = (batch, None)
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        comparison_tensor = self.get_competition(batch, components)
        # TODO: => Competition Hook
        return self.inference(comparison_tensor, components)

    def predict(self, batch):
        """
        Alias for forward
        """
        return self.forward(batch)

    def forward_comparison(self, batch):
        """
        Like ``forward`` but stop after the comparison step.

        Returns the comparison tensor instead of the inference result.
        """
        if isinstance(batch, torch.Tensor):
            batch = (batch, None)
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        return self.get_competition(batch, components)

    def loss_forward(self, batch):
        """
        Run the training path for ``batch`` and return the loss.

        Unlike ``forward`` the batch is passed on unchanged (no wrapping of
        bare tensors).
        """
        # TODO: manage different datatypes?
        components = self.components_layer()
        # TODO: => Component Hook
        comparison_tensor = self.get_competition(batch, components)
        # TODO: => Competition Hook
        return self.loss(comparison_tensor, batch, components)

    # Empty Initialization
    def init_components(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the components step.
        """
        ...

    def init_backbone(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the backbone step.
        """
        ...

    def init_comparison(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the comparison step.
        """
        ...

    def init_competition(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the competition step.
        """
        ...

    def init_loss(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the loss step.
        """
        ...

    def init_inference(self, hparams: HyperParameters) -> None:
        """
        All initialization necessary for the inference step.
        """
        ...

    # Empty Steps
    def components(self):
        """
        This step has no input.

        It returns the components.
        """
        raise NotImplementedError(
            "The components step has no reasonable default.")

    def backbone(self, batch, components):
        """
        The backbone step receives the data batch and the components.

        It can transform both by an arbitrary function.
        It returns the transformed batch and components, each of the same
        length as the original input. The default is the identity.
        """
        return batch, components

    def comparison(self, batch, components):
        """
        Takes a batch of size N and the component set of size M.

        It returns an NxMxD tensor containing D (usually 1) pairwise
        comparison measures.
        """
        raise NotImplementedError(
            "The comparison step has no reasonable default.")

    def competition(self, comparison_measures, components):
        """
        Takes the tensor of comparison measures.

        Assigns a competition vector to each class.
        """
        raise NotImplementedError(
            "The competition step has no reasonable default.")

    def loss(self, comparison_measures, batch, components):
        """
        Takes the tensor of comparison measures, the batch and the components.

        Calculates a single loss value.
        """
        raise NotImplementedError("The loss step has no reasonable default.")

    def inference(self, comparison_measures, components):
        """
        Takes the tensor of comparison measures and the components.

        Returns the inferred vector.
        """
        raise NotImplementedError(
            "The inference step has no reasonable default.")

    # Y Architecture Hooks

    # internal API, called by models and callbacks
    def register_torchmetric(
        self,
        name: Callable,
        metric: type[Metric],
        step: str = Steps.TRAINING,
        **metric_kwargs,
    ):
        """
        Register a callback that receives the epoch value of a metric.

        Args:
            name: The callback. At the end of every epoch of ``step`` it is
                called as ``name(value, self)`` with the computed metric.
            metric: The metric class. One instance per step is shared by all
                callbacks registered for that class.
            step: One of the ``Steps`` names except ``Steps.PREDICT``.
            **metric_kwargs: Passed to ``metric`` when its instance is
                created, i.e. only on the first registration of ``metric``
                for ``step``; ignored afterwards.

        Raises:
            ValueError: If ``step`` is ``Steps.PREDICT``.
        """
        if step == Steps.PREDICT:
            raise ValueError("Prediction metrics are not supported.")

        step_metrics = self.registered_metrics[step]
        step_callbacks = self.registered_metric_callbacks[step]

        # The lookup has to happen in the per-step dict: the outer dict is
        # keyed by step name and never contains a metric class.
        if metric not in step_metrics:
            step_metrics[metric] = metric(**metric_kwargs)
            step_callbacks[metric] = {name}
        else:
            step_callbacks[metric].add(name)

    def update_metrics_step(self, batch, step):
        """
        Feed the predictions for ``batch`` into all metrics of ``step``.

        ``batch`` is unpacked as an ``(inputs, targets)`` pair.
        """
        step_metrics = self.registered_metrics[step]
        if not step_metrics:
            # Nothing to update - skip the additional forward pass.
            return

        # Prediction Metrics
        preds = self(batch)
        _, y = batch
        for metric_instance in step_metrics.values():
            instance = metric_instance.to(self.device)
            # torchmetrics expects the predictions first, the targets second.
            instance(preds, y)

    def update_metrics_epoch(self, step):
        """
        Compute every metric of ``step``, notify its callbacks and reset it.
        """
        for metric, metric_instance in self.registered_metrics[step].items():
            instance = metric_instance.to(self.device)
            value = instance.compute()

            for callback in self.registered_metric_callbacks[step][metric]:
                callback(value, self)

            instance.reset()

    # Lightning steps
    # -------------------------------------------------------------------------
    # >>>> Training
    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Update the training metrics and return the loss for ``batch``."""
        self.update_metrics_step(batch, Steps.TRAINING)
        return self.loss_forward(batch)

    def training_epoch_end(self, outs) -> None:
        """Publish and reset the training metrics."""
        self.update_metrics_epoch(Steps.TRAINING)

    # >>>> Validation
    def validation_step(self, batch, batch_idx):
        """Update the validation metrics and return the loss for ``batch``."""
        self.update_metrics_step(batch, Steps.VALIDATION)
        return self.loss_forward(batch)

    def validation_epoch_end(self, outs) -> None:
        """Publish and reset the validation metrics."""
        self.update_metrics_epoch(Steps.VALIDATION)

    # >>>> Test
    def test_step(self, batch, batch_idx):
        """Update the test metrics and return the loss for ``batch``."""
        self.update_metrics_step(batch, Steps.TEST)
        return self.loss_forward(batch)

    def test_epoch_end(self, outs) -> None:
        """Publish and reset the test metrics."""
        self.update_metrics_epoch(Steps.TEST)

    # >>>> Prediction
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the prediction for ``batch``."""
        return self.predict(batch)

    # Check points
    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """
        Nest the saved hyperparameters under the key ``'hparams'``.

        ``__init__`` takes a single ``hparams`` argument, so the stored
        hyperparameters are wrapped accordingly before the checkpoint is
        written.
        """
        # Compatible with Lightning
        checkpoint["hyper_parameters"] = {
            'hparams': checkpoint["hyper_parameters"]
        }
        return super().on_save_checkpoint(checkpoint)