fix: fix problems with y architecture and checkpoint

Alexander Engelsberger 2022-06-12 10:36:15 +02:00
parent fe729781fc
commit be7d7f43bd
4 changed files with 47 additions and 32 deletions

View File

@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
# ##############################################################################
if __name__ == "__main__":
def main():
# ------------------------------------------------------------
# DATA
# ------------------------------------------------------------
@@ -51,7 +51,7 @@ if __name__ == "__main__":
# Create Model
model = GMLVQ(hyperparameters)
print(model)
print(model.hparams)
# ------------------------------------------------------------
# TRAINING
@@ -74,15 +74,27 @@ if __name__ == "__main__":
vis = VisGMLVQ2D(data=train_ds)
# Define trainer
trainer = pl.Trainer(
callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
],
max_epochs=1000,
)
trainer = pl.Trainer(callbacks=[
vis,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),
], )
# Train
trainer.fit(model, train_loader)
# Manual save
trainer.save_checkpoint("./y_arch.ckpt")
# Load saved model
new_model = GMLVQ.load_from_checkpoint(
checkpoint_path="./y_arch.ckpt",
strict=True,
)
print(new_model.hparams)
if __name__ == "__main__":
main()
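
Note: the example now ends with a manual checkpoint save followed by a strict reload, which is what exercises the hyperparameter fix in the architecture file below. A minimal, self-contained sketch of that save/reload round trip, with a hypothetical ToyModule (plain keyword hyperparameters) standing in for GMLVQ and random data standing in for train_ds:

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(pl.LightningModule):
    # Hypothetical stand-in for GMLVQ, reduced to what the round trip needs.

    def __init__(self, input_dim: int = 2, lr: float = 0.1) -> None:
        super().__init__()
        self.save_hyperparameters()  # records input_dim and lr in the checkpoint
        self.layer = torch.nn.Linear(input_dim, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


def main():
    train_ds = TensorDataset(torch.randn(64, 2), torch.randn(64, 1))
    train_loader = DataLoader(train_ds, batch_size=16)

    model = ToyModule()
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(model, train_loader)

    # Manual save, then a strict reload from the checkpoint file.
    trainer.save_checkpoint("./toy.ckpt")
    new_model = ToyModule.load_from_checkpoint(
        checkpoint_path="./toy.ckpt",
        strict=True,
    )
    print(new_model.hparams)


if __name__ == "__main__":
    main()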

View File

@@ -3,8 +3,11 @@ Proto Y Architecture
Network architecture for Component based Learning.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Set,
@@ -23,15 +26,20 @@ class BaseYArchitecture(pl.LightningModule):
...
# Fields
registered_metrics: Dict[Type[Metric], Metric] = {}
registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
registered_metrics: dict[type[Metric], Metric] = {}
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
# Type Hints for Necessary Fields
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
if type(hparams) is dict:
hparams = self.HyperParameters(**hparams)
super().__init__()
self.save_hyperparameters(hparams.__dict__)
# Common Steps
self.init_components(hparams)
self.init_latent(hparams)
@@ -165,7 +173,7 @@ class BaseYArchitecture(pl.LightningModule):
def register_torchmetric(
self,
name: Callable,
metric: Type[Metric],
metric: type[Metric],
**metric_kwargs,
):
if metric not in self.registered_metrics:
@@ -210,3 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
# Other Hooks
def training_epoch_end(self, outs) -> None:
self.update_metrics_epoch()
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
checkpoint["hyper_parameters"] = {
'hparams': checkpoint["hyper_parameters"]
}
return super().on_save_checkpoint(checkpoint)
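
Note: the __init__ and on_save_checkpoint hunks work together. save_hyperparameters(hparams.__dict__) flattens the HyperParameters dataclass into the checkpoint, on_save_checkpoint re-nests those fields under the single 'hparams' key that __init__ accepts, and __init__ rebuilds the dataclass from the dict it gets back on reload. A condensed sketch of just this logic, using a hypothetical DemoArchitecture in place of BaseYArchitecture:

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import pytorch_lightning as pl
import torch


class DemoArchitecture(pl.LightningModule):
    # Hypothetical stand-in for BaseYArchitecture, reduced to the hparams handling.

    @dataclass
    class HyperParameters:
        input_dim: int = 2
        lr: float = 0.1

    def __init__(self, hparams) -> None:
        # On reload, the stored hyperparameters come back as a plain dict;
        # rebuild the dataclass before reading attributes from it.
        if type(hparams) is dict:
            hparams = self.HyperParameters(**hparams)
        super().__init__()
        self.save_hyperparameters(hparams.__dict__)
        self.layer = torch.nn.Linear(hparams.input_dim, 1)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # Nest the flattened fields under the single 'hparams' init argument
        # so load_from_checkpoint can match the __init__ signature.
        checkpoint["hyper_parameters"] = {
            "hparams": checkpoint["hyper_parameters"]
        }
        return super().on_save_checkpoint(checkpoint)

With this in place, DemoArchitecture.load_from_checkpoint("demo.ckpt") hands the stored dict back as the hparams argument, and the dataclass is reconstructed inside __init__.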

View File

@@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
lr: float = 0.1
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
return self.hparams.optimizer(self.parameters(),
lr=self.hparams.lr) # type: ignore
class MultipleLearningRateMixin(BaseYArchitecture):
@@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
lr: dict = field(default_factory=lambda: dict())
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Steps
# ----------------------------------------------------------------------------------------------------
def __init__(self, hparams: HyperParameters) -> None:
super().__init__(hparams)
self.lr = hparams.lr
self.optimizer = hparams.optimizer
# Hooks
# ----------------------------------------------------------------------------------------------------
def configure_optimizers(self):
optimizers = []
for name, lr in self.lr.items():
for name, lr in self.hparams.lr.items():
if not hasattr(self, name):
raise ValueError(f"{name} is not a parameter of {self}")
else:
model_part = getattr(self, name)
if isinstance(model_part, Parameter):
optimizers.append(
self.optimizer(
self.hparams.optimizer(
[model_part],
lr=lr, # type: ignore
))
elif hasattr(model_part, "parameters"):
optimizers.append(
self.optimizer(
self.hparams.optimizer(
model_part.parameters(),
lr=lr, # type: ignore
))
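
Note: both mixins drop their __init__ overrides and have configure_optimizers read the learning rate(s) and optimizer class straight from self.hparams, so the optimizers are built from the same values that a reloaded checkpoint restores. A rough, self-contained sketch of the multiple-learning-rate pattern; DemoMultiLR and its backbone/head layers are made up for illustration:

from __future__ import annotations

from dataclasses import dataclass, field

import pytorch_lightning as pl
import torch
from torch.nn.parameter import Parameter


class DemoMultiLR(pl.LightningModule):
    # Hypothetical sketch: one optimizer per named model part, with the
    # per-part learning rates and the optimizer class taken from hparams.

    @dataclass
    class HyperParameters:
        lr: dict = field(default_factory=dict)
        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam

    def __init__(self, hparams) -> None:
        if type(hparams) is dict:
            hparams = self.HyperParameters(**hparams)
        super().__init__()
        self.save_hyperparameters(hparams.__dict__)
        self.backbone = torch.nn.Linear(2, 4)
        self.head = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizers = []
        for name, lr in self.hparams.lr.items():
            if not hasattr(self, name):
                raise ValueError(f"{name} is not a parameter of {self}")
            model_part = getattr(self, name)
            if isinstance(model_part, Parameter):
                # A bare tensor parameter gets wrapped in a list for the optimizer.
                optimizers.append(self.hparams.optimizer([model_part], lr=lr))
            else:
                optimizers.append(
                    self.hparams.optimizer(model_part.parameters(), lr=lr))
        return optimizers


# Usage: DemoMultiLR(DemoMultiLR.HyperParameters(lr={"backbone": 0.01, "head": 0.1}))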

View File

@@ -1,5 +1,7 @@
from .glvq import GLVQ
from .gmlvq import GMLVQ
__all__ = [
"GLVQ",
"GMLVQ",
]