fix: fix problems with y architecture and checkpoint

Alexander Engelsberger
2022-06-12 10:36:15 +02:00
parent fe729781fc
commit be7d7f43bd
4 changed files with 47 additions and 32 deletions


@@ -3,8 +3,11 @@ Proto Y Architecture
Network architecture for Component based Learning.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Set,
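
The switch from the typing module generics (Dict, Set, Type) to the builtin forms below relies on the newly added "from __future__ import annotations", which keeps annotations from being evaluated at runtime and therefore works on Python versions older than 3.9. A minimal standalone sketch, not part of the diff, showing why the import matters:

from __future__ import annotations  # PEP 563: annotations are stored as strings

class Registry:
    # Without the future import, this line would raise
    # "TypeError: 'type' object is not subscriptable" on Python 3.8,
    # because dict[str, int] would be evaluated when the class body runs.
    entries: dict[str, int] = {}
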
@@ -23,15 +26,20 @@ class BaseYArchitecture(pl.LightningModule):
...
# Fields
registered_metrics: Dict[Type[Metric], Metric] = {}
registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
registered_metrics: dict[type[Metric], Metric] = {}
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
if type(hparams) is dict:
hparams = self.HyperParameters(**hparams)
super().__init__()
self.save_hyperparameters(hparams.__dict__)
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
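
The constructor now accepts either the HyperParameters dataclass or a plain dict, converting the latter via self.HyperParameters(**hparams) before flattening it with save_hyperparameters(hparams.__dict__). A hedged sketch, assuming a concrete subclass MyYModel whose HyperParameters has an lr field:

# Both constructions are equivalent after the new dict branch in __init__:
model_a = MyYModel(MyYModel.HyperParameters(lr=0.01))
model_b = MyYModel({"lr": 0.01})  # converted to MyYModel.HyperParameters internally
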
@@ -165,7 +173,7 @@ class BaseYArchitecture(pl.LightningModule):
def register_torchmetric(
self,
name: Callable,
metric: Type[Metric],
metric: type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
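
register_torchmetric instantiates each metric class at most once (guarded by the check above, with metric_kwargs forwarded to the constructor) and stores the passed Callable as a callback for that metric. A possible call, assuming an already constructed model, a hypothetical on_accuracy callback, and a torchmetrics version whose Accuracy accepts num_classes directly:

from torchmetrics import Accuracy

def on_accuracy(value):
    # hypothetical callback, invoked with the computed metric value
    print("accuracy:", float(value))

model.register_torchmetric(on_accuracy, Accuracy, num_classes=3)
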
@@ -210,3 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"]
}
return super().on_save_checkpoint(checkpoint)
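
Nesting the saved values under a single 'hparams' key is what makes checkpoints load again: __init__ takes one hparams argument, so Lightning can hand the stored dict back under that name, and the new dict branch converts it into the HyperParameters dataclass. A round-trip sketch, where the file name and model class are assumptions:

trainer.save_checkpoint("y_model.ckpt")                 # hyper_parameters becomes {"hparams": {...}}
model = MyYModel.load_from_checkpoint("y_model.ckpt")   # __init__ receives hparams=<stored dict>
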


@@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
return self.hparams.optimizer(self.parameters(),
lr=self.hparams.lr) # type: ignore
class MultipleLearningRateMixin(BaseYArchitecture):
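
configure_optimizers now reads lr and optimizer from self.hparams instead of from instance attributes set in a custom __init__, so the values restored from a checkpoint are the ones actually used. A usage sketch, assuming a concrete MyYModel that mixes in SingleLearningRateMixin and needs no further hyperparameters:

import torch

hparams = MyYModel.HyperParameters(
    lr=0.05,
    optimizer=torch.optim.SGD,  # any torch.optim.Optimizer subclass
)
model = MyYModel(hparams)
# configure_optimizers() returns SGD(model.parameters(), lr=0.05)
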
@@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
lr: dict = field(default_factory=lambda: dict())
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
optimizers = []
for name, lr in self.lr.items():
for name, lr in self.hparams.lr.items():
if not hasattr(self, name):
raise ValueError(f"{name} is not a parameter of {self}")
else:
model_part = getattr(self, name)
if isinstance(model_part, Parameter):
optimizers.append(
self.optimizer(
self.hparams.optimizer(
[model_part],
lr=lr, # type: ignore
))
elif hasattr(model_part, "parameters"):
optimizers.append(
self.optimizer(
self.hparams.optimizer(
model_part.parameters(),
lr=lr, # type: ignore
))
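
Each key in hparams.lr must name an attribute of the model, either a torch.nn.Parameter or a module exposing .parameters(), and one optimizer is built per entry. A hedged configuration sketch: components_layer comes from the base class, while MyYModel and backbone are assumed names for illustration.

hparams = MyYModel.HyperParameters(
    lr={"components_layer": 0.1, "backbone": 0.01},
    optimizer=torch.optim.Adam,
)
model = MyYModel(hparams)
# configure_optimizers() builds one Adam per entry, e.g. Adam over the
# components_layer parameters with lr=0.1 and over backbone with lr=0.01.
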