feat: remove old architecture

Alexander Engelsberger
2022-08-15 12:14:14 +02:00
parent bcf9c6bdb1
commit 5a89f24c10
44 changed files with 371 additions and 3757 deletions

View File

@@ -0,0 +1,280 @@
"""
Proto Y Architecture
Network architecture for component-based learning.
"""
from __future__ import annotations
from dataclasses import asdict, dataclass
from typing import Any, Callable
import pytorch_lightning as pl
import torch
from torchmetrics import Metric
class Steps:
    """Namespace of step names used as dictionary keys."""
TRAINING = "training"
VALIDATION = "validation"
TEST = "test"
PREDICT = "predict"
class BaseYArchitecture(pl.LightningModule):
@dataclass
class HyperParameters:
"""
        Declare all hyperparameters in the inheriting class.
"""
...
# Fields
registered_metrics: dict[str, dict[type[Metric], Metric]] = {
Steps.TRAINING: {},
Steps.VALIDATION: {},
Steps.TEST: {},
}
registered_metric_callbacks: dict[str, dict[type[Metric],
set[Callable]]] = {
Steps.TRAINING: {},
Steps.VALIDATION: {},
Steps.TEST: {},
}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
if type(hparams) is dict:
self.save_hyperparameters(hparams)
# TODO: => Move into Component Child
del hparams["initialized_proto_shape"]
hparams = self.HyperParameters(**hparams)
else:
hparam_dict = asdict(hparams)
hparam_dict["component_initializer"] = None
            self.save_hyperparameters(hparam_dict)
super().__init__()
# Common Steps
self.init_components(hparams)
self.init_backbone(hparams)
self.init_comparison(hparams)
self.init_competition(hparams)
# Train Steps
self.init_loss(hparams)
# Inference Steps
self.init_inference(hparams)
# external API
def get_competition(self, batch, components):
latent_batch, latent_components = self.backbone(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
# TODO: => Comparison Hook
return comparison_tensor
def forward(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.inference(comparison_tensor, components)
def predict(self, batch):
"""
Alias for forward
"""
return self.forward(batch)
def forward_comparison(self, batch):
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
return self.get_competition(batch, components)
def loss_forward(self, batch):
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
comparison_tensor = self.get_competition(batch, components)
# TODO: => Competition Hook
return self.loss(comparison_tensor, batch, components)
# Empty Initialization
def init_components(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the components step.
"""
...
def init_backbone(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the backbone step.
"""
...
def init_comparison(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the comparison step.
"""
...
def init_competition(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the competition step.
"""
...
def init_loss(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the loss step.
"""
...
def init_inference(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the inference step.
"""
...
# Empty Steps
def components(self):
"""
This step has no input.
It returns the components.
"""
raise NotImplementedError(
"The components step has no reasonable default.")
def backbone(self, batch, components):
"""
The backbone step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
"""
return batch, components
def comparison(self, batch, components):
"""
Takes a batch of size N and the component set of size M.
It returns an NxMxD tensor containing D (usually 1) pairwise comparison measures.
"""
raise NotImplementedError(
"The comparison step has no reasonable default.")
def competition(self, comparison_measures, components):
"""
Takes the tensor of comparison measures.
Assigns a competition vector to each class.
"""
raise NotImplementedError(
"The competition step has no reasonable default.")
def loss(self, comparison_measures, batch, components):
"""
        Takes the tensor of comparison measures.
        Calculates a single loss value.
"""
raise NotImplementedError("The loss step has no reasonable default.")
def inference(self, comparison_measures, components):
"""
        Takes the tensor of comparison measures.
Returns the inferred vector.
"""
raise NotImplementedError(
"The inference step has no reasonable default.")
# Y Architecture Hooks
# internal API, called by models and callbacks
def register_torchmetric(
self,
name: Callable,
metric: type[Metric],
step: str = Steps.TRAINING,
**metric_kwargs,
):
if step == Steps.PREDICT:
raise ValueError("Prediction metrics are not supported.")
        if metric not in self.registered_metrics[step]:
self.registered_metrics[step][metric] = metric(**metric_kwargs)
self.registered_metric_callbacks[step][metric] = {name}
else:
self.registered_metric_callbacks[step][metric].add(name)
def update_metrics_step(self, batch, step):
# Prediction Metrics
preds = self(batch)
        _, y = batch
        for metric in self.registered_metrics[step]:
            instance = self.registered_metrics[step][metric].to(self.device)
            instance(preds, y)
def update_metrics_epoch(self, step):
for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device)
value = instance.compute()
for callback in self.registered_metric_callbacks[step][metric]:
callback(value, self)
instance.reset()
# Lightning steps
# -------------------------------------------------------------------------
# >>>> Training
def training_step(self, batch, batch_idx, optimizer_idx=None):
self.update_metrics_step(batch, Steps.TRAINING)
return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch(Steps.TRAINING)
# >>>> Validation
def validation_step(self, batch, batch_idx):
self.update_metrics_step(batch, Steps.VALIDATION)
return self.loss_forward(batch)
def validation_epoch_end(self, outs) -> None:
self.update_metrics_epoch(Steps.VALIDATION)
# >>>> Test
def test_step(self, batch, batch_idx):
self.update_metrics_step(batch, Steps.TEST)
return self.loss_forward(batch)
def test_epoch_end(self, outs) -> None:
self.update_metrics_epoch(Steps.TEST)
# >>>> Prediction
def predict_step(self, batch, batch_idx, dataloader_idx=0):
return self.predict(batch)
# Check points
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Wrap the hyperparameters so that, on load, Lightning passes them back as a single 'hparams' argument.
checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"]
}
return super().on_save_checkpoint(checkpoint)
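
# For orientation, below is a minimal sketch of how a concrete model could be
# assembled from this base class together with the mixins added later in this
# commit. The class name GLVQ, the import paths, and the example values are
# assumptions for illustration only; they are not part of this diff.

from dataclasses import dataclass

# Assumed import paths; the diff does not show the module file names.
from prototorch.models.architectures.comparison import SimpleComparisonMixin
from prototorch.models.architectures.competition import WTACompetitionMixin
from prototorch.models.architectures.components import SupervisedArchitecture
from prototorch.models.architectures.loss import GLVQLossMixin
from prototorch.models.architectures.optimization import SingleLearningRateMixin


class GLVQ(
        SupervisedArchitecture,
        SimpleComparisonMixin,
        GLVQLossMixin,
        WTACompetitionMixin,
        SingleLearningRateMixin,
):
    """Sketch: a GLVQ model composed from the Y-architecture mixins."""

    @dataclass
    class HyperParameters(
            SimpleComparisonMixin.HyperParameters,
            SingleLearningRateMixin.HyperParameters,
            GLVQLossMixin.HyperParameters,
            WTACompetitionMixin.HyperParameters,
            SupervisedArchitecture.HyperParameters,
    ):
        """Combined hyperparameters. Listing SupervisedArchitecture last puts
        its required fields (distribution, component_initializer) first in
        the generated __init__."""


# Hypothetical construction:
#   hparams = GLVQ.HyperParameters(
#       distribution={"num_classes": 3, "per_class": 2},
#       component_initializer=some_components_initializer,  # placeholder
#   )
#   model = GLVQ(hparams)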

View File

@@ -0,0 +1,137 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
import torch
from prototorch.core.distances import euclidean_distance
from prototorch.core.initializers import (
AbstractLinearTransformInitializer,
EyeLinearTransformInitializer,
)
from prototorch.models.architectures.base import BaseYArchitecture
from prototorch.nn.wrappers import LambdaLayer
from torch import Tensor
from torch.nn.parameter import Parameter
class SimpleComparisonMixin(BaseYArchitecture):
"""
Simple Comparison
A comparison layer that only uses the positions of the components
and the batch for dissimilarity computation.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
        comparison_fn: The comparison / dissimilarity function to use. Default: euclidean_distance.
        comparison_args: Keyword arguments for the comparison function. Default: {}.
        comparison_parameters: Parameters of the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_parameters: dict = field(default_factory=lambda: dict())
# Steps
# ----------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters):
self.comparison_layer = LambdaLayer(
fn=hparams.comparison_fn,
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
def comparison(self, batch, components):
comp_tensor, _ = components
batch_tensor, _ = batch
comp_tensor = comp_tensor.unsqueeze(1)
distances = self.comparison_layer(
batch_tensor,
comp_tensor,
**self.comparison_kwargs,
)
return distances
class OmegaComparisonMixin(SimpleComparisonMixin):
"""
Omega Comparison
    A comparison layer that applies a learnable linear transformation (omega)
    to both the batch and the components before the dissimilarity computation.
"""
_omega: torch.Tensor
# HyperParameters
# ----------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(SimpleComparisonMixin.HyperParameters):
"""
        input_dim: The dimensionality of the input. Required; no default.
latent_dim:
The dimensionality of the latent space. Default: 2.
omega_initializer:
The initializer to use for the omega matrix. Default: EyeLinearTransformInitializer.
"""
input_dim: int | None = None
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
# Steps
# ----------------------------------------------------------------------------------------------
def init_comparison(self, hparams: HyperParameters) -> None:
super().init_comparison(hparams)
# Initialize the omega matrix
if hparams.input_dim is None:
raise ValueError("input_dim must be specified.")
else:
omega = hparams.omega_initializer().generate(
hparams.input_dim,
hparams.latent_dim,
)
self.register_parameter("_omega", Parameter(omega))
self.comparison_kwargs = dict(omega=self._omega)
# Properties
# ----------------------------------------------------------------------------------------------
@property
def omega_matrix(self):
'''
Omega Matrix. Mapping applied to data and prototypes.
'''
return self._omega.detach().cpu()
@property
def lambda_matrix(self):
'''
        Lambda Matrix. Lambda = Omega @ Omega.T, the relevance matrix induced by the omega mapping.
'''
omega = self._omega.detach()
lam = omega @ omega.T
return lam.detach().cpu()
@property
def relevance_profile(self):
'''
Relevance Profile. Main Diagonal of the Lambda Matrix.
'''
return self.lambda_matrix.diag().abs()
@property
def classification_influence_profile(self):
'''
Classification Influence Profile. Influence of each dimension.
'''
lam = self.lambda_matrix
return lam.abs().sum(0)
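
# A quick, standalone numerical illustration of the matrix quantities exposed
# by the properties above (plain torch with assumed shapes; not part of the
# diff): omega has shape (input_dim, latent_dim), the lambda matrix is
# omega @ omega.T, the relevance profile is its absolute main diagonal, and
# the classification influence profile sums absolute values along one axis.

import torch

input_dim, latent_dim = 4, 2
omega = torch.randn(input_dim, latent_dim)

lam = omega @ omega.T           # (input_dim, input_dim) lambda matrix
relevance = lam.diag().abs()    # relevance profile: one value per input feature
influence = lam.abs().sum(0)    # classification influence per input feature

print(lam.shape, relevance, influence, sep="\n")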

View File

@@ -0,0 +1,29 @@
from dataclasses import dataclass
from prototorch.core.competitions import WTAC
from prototorch.models.architectures.base import BaseYArchitecture
class WTACompetitionMixin(BaseYArchitecture):
"""
Winner Take All Competition
    A competition layer that uses the winner-take-all strategy. Because the
    winning component directly determines the prediction, it is wired into
    the inference step.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
No hyperparameters.
"""
# Steps
# ----------------------------------------------------------------------------------------------------
def init_inference(self, hparams: HyperParameters):
self.competition_layer = WTAC()
def inference(self, comparison_measures, components):
comp_labels = components[1]
return self.competition_layer(comparison_measures, comp_labels)
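
# Conceptually, the winner-take-all competition picks, for every sample, the
# component with the smallest comparison measure and returns that component's
# label. A minimal torch-only sketch of that idea (an illustration of the
# principle, not the WTAC implementation):

import torch

def wta_sketch(comparison_measures: torch.Tensor,
               comp_labels: torch.Tensor) -> torch.Tensor:
    # comparison_measures: (N, M) dissimilarities of N samples to M components
    # comp_labels: (M,) label of each component
    winners = comparison_measures.argmin(dim=1)  # closest component per sample
    return comp_labels[winners]                  # predicted label per sample

distances = torch.tensor([[0.2, 1.5, 0.9],
                          [2.0, 0.1, 0.4]])
labels = torch.tensor([0, 1, 1])
print(wta_sketch(distances, labels))  # tensor([0, 1])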

View File

@@ -0,0 +1,64 @@
from dataclasses import dataclass
from prototorch.core.components import LabeledComponents
from prototorch.core.initializers import (
AbstractComponentsInitializer,
LabelsInitializer,
ZerosCompInitializer,
)
from prototorch.models import BaseYArchitecture
class SupervisedArchitecture(BaseYArchitecture):
"""
Supervised Architecture
    An architecture that uses a LabeledComponents layer as its components layer.
"""
components_layer: LabeledComponents
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters:
"""
distribution: A valid prototype distribution. No default possible.
        component_initializer: An implementation of AbstractComponentsInitializer. No default possible.
"""
distribution: "dict[str, int]"
component_initializer: AbstractComponentsInitializer
# Steps
# ----------------------------------------------------------------------------------------------------
def init_components(self, hparams: HyperParameters):
if hparams.component_initializer is not None:
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=hparams.component_initializer,
labels_initializer=LabelsInitializer(),
)
proto_shape = self.components_layer.components.shape[1:]
self.hparams["initialized_proto_shape"] = proto_shape
else:
# when restoring a checkpointed model
self.components_layer = LabeledComponents(
distribution=hparams.distribution,
components_initializer=ZerosCompInitializer(
self.hparams["initialized_proto_shape"]),
)
# Properties
# ----------------------------------------------------------------------------------------------------
@property
def prototypes(self):
"""
Returns the position of the prototypes.
"""
return self.components_layer.components.detach().cpu()
@property
def prototype_labels(self):
"""
Returns the labels of the prototypes.
"""
return self.components_layer.labels.detach().cpu()
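
# A hedged example of filling in these hyperparameters. The
# {"num_classes": ..., "per_class": ...} layout of the distribution dict and
# the integer shape argument to ZerosCompInitializer are assumptions based on
# typical prototorch usage; in practice a data-driven components initializer
# would replace the zeros initializer.

from prototorch.core.initializers import ZerosCompInitializer
from prototorch.models.architectures.components import SupervisedArchitecture  # assumed path

# Hypothetical values: 3 classes with 2 prototypes each, 4 input features.
hparams = SupervisedArchitecture.HyperParameters(
    distribution={"num_classes": 3, "per_class": 2},
    component_initializer=ZerosCompInitializer(4),  # assumed shape argument
)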

View File

@@ -0,0 +1,42 @@
from dataclasses import dataclass, field
from prototorch.core.losses import GLVQLoss
from prototorch.models.architectures.base import BaseYArchitecture
class GLVQLossMixin(BaseYArchitecture):
"""
GLVQ Loss
A loss layer that uses the Generalized Learning Vector Quantization (GLVQ) loss.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
margin: The margin of the GLVQ loss. Default: 0.0.
transfer_fn: Transfer function to use. Default: sigmoid_beta.
transfer_args: Keyword arguments for the transfer function. Default: {beta: 10.0}.
"""
margin: float = 0.0
transfer_fn: str = "sigmoid_beta"
transfer_args: dict = field(default_factory=lambda: dict(beta=10.0))
# Steps
# ----------------------------------------------------------------------------------------------------
def init_loss(self, hparams: HyperParameters):
self.loss_layer = GLVQLoss(
margin=hparams.margin,
transfer_fn=hparams.transfer_fn,
**hparams.transfer_args,
)
def loss(self, comparison_measures, batch, components):
target = batch[1]
comp_labels = components[1]
loss = self.loss_layer(comparison_measures, target, comp_labels)
self.log('loss', loss)
return loss
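
# For reference, the classifier function wrapped by GLVQLoss is, per sample,
# mu(x) = (d_plus - d_minus) / (d_plus + d_minus), where d_plus is the distance
# to the closest prototype with the correct label and d_minus the distance to
# the closest prototype with any other label; the configured transfer function
# (here a beta-sigmoid) is then applied to mu shifted by the margin. A minimal
# torch sketch of the mu computation (an illustration, not the GLVQLoss
# implementation):

import torch

def glvq_mu(distances, targets, proto_labels):
    # distances: (N, M), targets: (N,), proto_labels: (M,)
    matching = targets.unsqueeze(1) == proto_labels.unsqueeze(0)  # (N, M)
    inf = torch.full_like(distances, float("inf"))
    d_plus = torch.where(matching, distances, inf).min(dim=1).values
    d_minus = torch.where(~matching, distances, inf).min(dim=1).values
    return (d_plus - d_minus) / (d_plus + d_minus)

d = torch.tensor([[0.2, 1.0], [0.8, 0.3]])
y = torch.tensor([0, 1])
plabels = torch.tensor([0, 1])
print(glvq_mu(d, y, plabels))  # negative values indicate correct classification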

View File

@@ -0,0 +1,73 @@
from dataclasses import dataclass, field
from typing import Type
import torch
from prototorch.models import BaseYArchitecture
from torch.nn.parameter import Parameter
class SingleLearningRateMixin(BaseYArchitecture):
"""
Single Learning Rate
All parameters are updated with a single learning rate.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.hparams.optimizer(self.parameters(),
lr=self.hparams.lr) # type: ignore
class MultipleLearningRateMixin(BaseYArchitecture):
"""
Multiple Learning Rates
    Defines a separate learning rate for each named part of the model.
"""
# HyperParameters
# ----------------------------------------------------------------------------------------------------
@dataclass
class HyperParameters(BaseYArchitecture.HyperParameters):
"""
        lr: Dictionary mapping attribute names of the model to learning rates. Default: {}.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
optimizers = []
for name, lr in self.hparams.lr.items():
if not hasattr(self, name):
raise ValueError(f"{name} is not a parameter of {self}")
else:
model_part = getattr(self, name)
if isinstance(model_part, Parameter):
optimizers.append(
self.hparams.optimizer(
[model_part],
lr=lr, # type: ignore
))
elif hasattr(model_part, "parameters"):
optimizers.append(
self.hparams.optimizer(
model_part.parameters(),
lr=lr, # type: ignore
))
return optimizers
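
# A hedged usage note: the keys of the lr dictionary must be attribute names
# of the assembled model, for example the components_layer defined by
# SupervisedArchitecture or the _omega parameter registered by
# OmegaComparisonMixin. Returning a list here gives Lightning (1.x, automatic
# optimization) one optimizer per entry, each stepped every batch, which is
# why training_step in the base class accepts an optimizer_idx argument.
# Hypothetical values:

from prototorch.models.architectures.optimization import MultipleLearningRateMixin  # assumed path

hparams = MultipleLearningRateMixin.HyperParameters(
    lr={"components_layer": 0.1, "_omega": 0.5},
)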