fix: fix problems with y architecture and checkpoint
This commit is contained in:
parent
fe729781fc
commit
be7d7f43bd
@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
# ##############################################################################
|
# ##############################################################################
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
|
|
||||||
|
def main():
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
# DATA
|
# DATA
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
@ -51,7 +51,7 @@ if __name__ == "__main__":
|
|||||||
# Create Model
|
# Create Model
|
||||||
model = GMLVQ(hyperparameters)
|
model = GMLVQ(hyperparameters)
|
||||||
|
|
||||||
print(model)
|
print(model.hparams)
|
||||||
|
|
||||||
# ------------------------------------------------------------
|
# ------------------------------------------------------------
|
||||||
# TRAINING
|
# TRAINING
|
||||||
@ -74,15 +74,27 @@ if __name__ == "__main__":
|
|||||||
vis = VisGMLVQ2D(data=train_ds)
|
vis = VisGMLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Define trainer
|
# Define trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(callbacks=[
|
||||||
callbacks=[
|
vis,
|
||||||
vis,
|
stopping_criterion,
|
||||||
stopping_criterion,
|
es,
|
||||||
es,
|
PlotLambdaMatrixToTensorboard(),
|
||||||
PlotLambdaMatrixToTensorboard(),
|
], )
|
||||||
],
|
|
||||||
max_epochs=1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Train
|
# Train
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
|
||||||
|
# Manual save
|
||||||
|
trainer.save_checkpoint("./y_arch.ckpt")
|
||||||
|
|
||||||
|
# Load saved model
|
||||||
|
new_model = GMLVQ.load_from_checkpoint(
|
||||||
|
checkpoint_path="./y_arch.ckpt",
|
||||||
|
strict=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(new_model.hparams)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
@ -3,8 +3,11 @@ Proto Y Architecture
|
|||||||
|
|
||||||
Network architecture for Component based Learning.
|
Network architecture for Component based Learning.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Set,
|
Set,
|
||||||
@ -23,15 +26,20 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
...
|
...
|
||||||
|
|
||||||
# Fields
|
# Fields
|
||||||
registered_metrics: Dict[Type[Metric], Metric] = {}
|
registered_metrics: dict[type[Metric], Metric] = {}
|
||||||
registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
|
registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
|
||||||
|
|
||||||
# Type Hints for Necessary Fields
|
# Type Hints for Necessary Fields
|
||||||
components_layer: torch.nn.Module
|
components_layer: torch.nn.Module
|
||||||
|
|
||||||
def __init__(self, hparams) -> None:
|
def __init__(self, hparams) -> None:
|
||||||
|
if type(hparams) is dict:
|
||||||
|
hparams = self.HyperParameters(**hparams)
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.save_hyperparameters(hparams.__dict__)
|
||||||
|
|
||||||
# Common Steps
|
# Common Steps
|
||||||
self.init_components(hparams)
|
self.init_components(hparams)
|
||||||
self.init_latent(hparams)
|
self.init_latent(hparams)
|
||||||
@ -165,7 +173,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
def register_torchmetric(
|
def register_torchmetric(
|
||||||
self,
|
self,
|
||||||
name: Callable,
|
name: Callable,
|
||||||
metric: Type[Metric],
|
metric: type[Metric],
|
||||||
**metric_kwargs,
|
**metric_kwargs,
|
||||||
):
|
):
|
||||||
if metric not in self.registered_metrics:
|
if metric not in self.registered_metrics:
|
||||||
@ -210,3 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
# Other Hooks
|
# Other Hooks
|
||||||
def training_epoch_end(self, outs) -> None:
|
def training_epoch_end(self, outs) -> None:
|
||||||
self.update_metrics_epoch()
|
self.update_metrics_epoch()
|
||||||
|
|
||||||
|
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
||||||
|
checkpoint["hyper_parameters"] = {
|
||||||
|
'hparams': checkpoint["hyper_parameters"]
|
||||||
|
}
|
||||||
|
return super().on_save_checkpoint(checkpoint)
|
||||||
|
@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
|
|||||||
lr: float = 0.1
|
lr: float = 0.1
|
||||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||||
|
|
||||||
# Steps
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
def __init__(self, hparams: HyperParameters) -> None:
|
|
||||||
super().__init__(hparams)
|
|
||||||
self.lr = hparams.lr
|
|
||||||
self.optimizer = hparams.optimizer
|
|
||||||
|
|
||||||
# Hooks
|
# Hooks
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
return self.optimizer(self.parameters(), lr=self.lr) # type: ignore
|
return self.hparams.optimizer(self.parameters(),
|
||||||
|
lr=self.hparams.lr) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class MultipleLearningRateMixin(BaseYArchitecture):
|
class MultipleLearningRateMixin(BaseYArchitecture):
|
||||||
@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
|
|||||||
lr: dict = field(default_factory=lambda: dict())
|
lr: dict = field(default_factory=lambda: dict())
|
||||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||||
|
|
||||||
# Steps
|
|
||||||
# ----------------------------------------------------------------------------------------------------
|
|
||||||
def __init__(self, hparams: HyperParameters) -> None:
|
|
||||||
super().__init__(hparams)
|
|
||||||
self.lr = hparams.lr
|
|
||||||
self.optimizer = hparams.optimizer
|
|
||||||
|
|
||||||
# Hooks
|
# Hooks
|
||||||
# ----------------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------------
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
optimizers = []
|
optimizers = []
|
||||||
for name, lr in self.lr.items():
|
for name, lr in self.hparams.lr.items():
|
||||||
if not hasattr(self, name):
|
if not hasattr(self, name):
|
||||||
raise ValueError(f"{name} is not a parameter of {self}")
|
raise ValueError(f"{name} is not a parameter of {self}")
|
||||||
else:
|
else:
|
||||||
model_part = getattr(self, name)
|
model_part = getattr(self, name)
|
||||||
if isinstance(model_part, Parameter):
|
if isinstance(model_part, Parameter):
|
||||||
optimizers.append(
|
optimizers.append(
|
||||||
self.optimizer(
|
self.hparams.optimizer(
|
||||||
[model_part],
|
[model_part],
|
||||||
lr=lr, # type: ignore
|
lr=lr, # type: ignore
|
||||||
))
|
))
|
||||||
elif hasattr(model_part, "parameters"):
|
elif hasattr(model_part, "parameters"):
|
||||||
optimizers.append(
|
optimizers.append(
|
||||||
self.optimizer(
|
self.hparams.optimizer(
|
||||||
model_part.parameters(),
|
model_part.parameters(),
|
||||||
lr=lr, # type: ignore
|
lr=lr, # type: ignore
|
||||||
))
|
))
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
from .glvq import GLVQ
|
from .glvq import GLVQ
|
||||||
|
from .gmlvq import GMLVQ
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"GLVQ",
|
"GLVQ",
|
||||||
|
"GMLVQ",
|
||||||
]
|
]
|
||||||
|
Loading…
Reference in New Issue
Block a user