fix: fix problems with y architecture and checkpoint

Alexander Engelsberger 2022-06-12 10:36:15 +02:00
parent fe729781fc
commit be7d7f43bd
4 changed files with 47 additions and 32 deletions

[File 1 of 4: GMLVQ training example script]

@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
 # ##############################################################################
-if __name__ == "__main__":
+def main():

     # ------------------------------------------------------------
     # DATA
     # ------------------------------------------------------------
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     # Create Model
     model = GMLVQ(hyperparameters)
-    print(model)
+    print(model.hparams)

     # ------------------------------------------------------------
     # TRAINING
     # ------------------------------------------------------------
@@ -74,15 +74,27 @@ if __name__ == "__main__":
     vis = VisGMLVQ2D(data=train_ds)

     # Define trainer
-    trainer = pl.Trainer(
-        callbacks=[
-            vis,
-            stopping_criterion,
-            es,
-            PlotLambdaMatrixToTensorboard(),
-        ],
-        max_epochs=1000,
-    )
+    trainer = pl.Trainer(callbacks=[
+        vis,
+        stopping_criterion,
+        es,
+        PlotLambdaMatrixToTensorboard(),
+    ], )

     # Train
     trainer.fit(model, train_loader)

+    # Manual save
+    trainer.save_checkpoint("./y_arch.ckpt")
+
+    # Load saved model
+    new_model = GMLVQ.load_from_checkpoint(
+        checkpoint_path="./y_arch.ckpt",
+        strict=True,
+    )
+
+    print(new_model.hparams)
+
+
+if __name__ == "__main__":
+    main()
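
The example now saves a checkpoint right after training and reloads it, which only works because the base class stores and restores the hyperparameters (see the next file). A minimal sketch of that round-trip; the helper name and its arguments are illustrative, only the checkpoint path and the load_from_checkpoint call come from the diff above:

# Sketch only: round-trip a trained Lightning model through a checkpoint.
# `roundtrip` is a hypothetical helper, not part of this repository.
def roundtrip(trainer, model_cls, model, path="./y_arch.ckpt"):
    trainer.save_checkpoint(path)               # hparams end up under "hyper_parameters"
    restored = model_cls.load_from_checkpoint(  # e.g. model_cls=GMLVQ, as in the example
        checkpoint_path=path,
        strict=True,
    )
    assert restored.hparams == model.hparams    # hyperparameters survive the reload
    return restored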

[File 2 of 4: Proto Y Architecture base module (BaseYArchitecture)]

@@ -3,8 +3,11 @@ Proto Y Architecture
 Network architecture for Component based Learning.
 """
+from __future__ import annotations
+
 from dataclasses import dataclass
 from typing import (
+    Any,
     Callable,
     Dict,
     Set,
@@ -23,15 +26,20 @@ class BaseYArchitecture(pl.LightningModule):
         ...

     # Fields
-    registered_metrics: Dict[Type[Metric], Metric] = {}
-    registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
+    registered_metrics: dict[type[Metric], Metric] = {}
+    registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}

     # Type Hints for Necessary Fields
     components_layer: torch.nn.Module

     def __init__(self, hparams) -> None:
+        if type(hparams) is dict:
+            hparams = self.HyperParameters(**hparams)
+
         super().__init__()
+        self.save_hyperparameters(hparams.__dict__)
+
         # Common Steps
         self.init_components(hparams)
         self.init_latent(hparams)
@@ -165,7 +173,7 @@ class BaseYArchitecture(pl.LightningModule):
     def register_torchmetric(
         self,
         name: Callable,
-        metric: Type[Metric],
+        metric: type[Metric],
         **metric_kwargs,
     ):
         if metric not in self.registered_metrics:
@@ -210,3 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
     # Other Hooks
     def training_epoch_end(self, outs) -> None:
         self.update_metrics_epoch()
+
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        checkpoint["hyper_parameters"] = {
+            'hparams': checkpoint["hyper_parameters"]
+        }
+        return super().on_save_checkpoint(checkpoint)
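
Taken together, these changes define the checkpoint contract: save_hyperparameters(hparams.__dict__) writes the dataclass fields into the checkpoint, on_save_checkpoint nests them under the key 'hparams' so Lightning passes them back as the single hparams argument of __init__, and the new dict branch in __init__ rebuilds the HyperParameters dataclass. A stripped-down sketch of that flow outside Lightning; the class and field names below are illustrative, not from the repo:

from dataclasses import dataclass

@dataclass
class HyperParameters:          # stand-in for a model's nested HyperParameters
    lr: float = 0.1

class SomeModel:                # stand-in for a BaseYArchitecture subclass
    HyperParameters = HyperParameters

    def __init__(self, hparams) -> None:
        # Same conversion as BaseYArchitecture.__init__ above.
        if type(hparams) is dict:
            hparams = self.HyperParameters(**hparams)
        self.hparams = hparams

# Shape produced by on_save_checkpoint inside the checkpoint file:
checkpoint = {"hyper_parameters": {"hparams": {"lr": 0.01}}}
# load_from_checkpoint effectively calls the constructor with that mapping:
model = SomeModel(**checkpoint["hyper_parameters"])
assert isinstance(model.hparams, HyperParameters)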

[File 3 of 4: learning-rate mixins (SingleLearningRateMixin, MultipleLearningRateMixin)]

@@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
         lr: float = 0.1
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
+        return self.hparams.optimizer(self.parameters(),
+                                      lr=self.hparams.lr)  # type: ignore


 class MultipleLearningRateMixin(BaseYArchitecture):
@@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
         lr: dict = field(default_factory=lambda: dict())
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
         optimizers = []
-        for name, lr in self.lr.items():
+        for name, lr in self.hparams.lr.items():
             if not hasattr(self, name):
                 raise ValueError(f"{name} is not a parameter of {self}")
             else:
                 model_part = getattr(self, name)
                 if isinstance(model_part, Parameter):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             [model_part],
                             lr=lr,  # type: ignore
                         ))
                 elif hasattr(model_part, "parameters"):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             model_part.parameters(),
                             lr=lr,  # type: ignore
                         ))
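
With the hyperparameters now registered via save_hyperparameters in the base class, the mixins no longer copy lr and optimizer onto self; configure_optimizers reads them from self.hparams, so the values restored from a checkpoint are the ones the optimizers actually use. A minimal standalone sketch of that pattern; the module below is illustrative, not from the repo:

import pytorch_lightning as pl
import torch

class TinyModule(pl.LightningModule):  # hypothetical example, not in this repo
    def __init__(self, lr: float = 0.1, optimizer=torch.optim.Adam):
        super().__init__()
        self.save_hyperparameters()     # lr and optimizer now live in self.hparams
        self.weight = torch.nn.Parameter(torch.zeros(1))

    def configure_optimizers(self):
        # Same access pattern the mixins use after this commit.
        return self.hparams.optimizer(self.parameters(), lr=self.hparams.lr)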

[File 4 of 4: package __init__ exports]

@@ -1,5 +1,7 @@
 from .glvq import GLVQ
+from .gmlvq import GMLVQ

 __all__ = [
     "GLVQ",
+    "GMLVQ",
 ]