4 Commits

Author                  SHA1        Message                                                    Date
Alexander Engelsberger  ed83138e1f  build: bump version 1.0.0a3 → 1.0.0a4                      2022-06-12 11:52:06 +02:00
Alexander Engelsberger  1be7d7ec09  fix: dont save component initializer as hparm              2022-06-12 11:40:33 +02:00
Alexander Engelsberger  60d2a1d2c9  fix: dont save prototype initializer in yarch checkpoint  2022-06-12 11:12:55 +02:00
Alexander Engelsberger  be7d7f43bd  fix: fix problems with y architecture and checkpoint      2022-06-12 10:36:15 +02:00
9 changed files with 73 additions and 47 deletions
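Taken together, these commits make the Y-architecture models checkpointable: component and prototype initializers are no longer stored as hyperparameters, the shape of the initialized prototypes is recorded instead so a model can be rebuilt from a checkpoint, and the version is bumped from 1.0.0a3 to 1.0.0a4. The file-by-file diffs follow.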

View File

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.0.0a3
+current_version = 1.0.0a4
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
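The parse pattern splits a version string into major, minor, patch, and an optional pre-release suffix, which is how 1.0.0a3 → 1.0.0a4 is handled. A quick standalone check (a sketch, not part of the repository):

import re

PARSE = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?"
match = re.match(PARSE, "1.0.0a4")
print(match.groupdict())  # {'major': '1', 'minor': '0', 'patch': '0', 'release': 'a4'}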

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 # The full version, including alpha/beta/rc tags
 #
-release = "1.0.0-a3"
+release = "1.0.0-a4"
 # -- General configuration ---------------------------------------------------

View File

@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
 # ##############################################################################
-if __name__ == "__main__":
+def main():
     # ------------------------------------------------------------
     # DATA
     # ------------------------------------------------------------
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     # Create Model
     model = GMLVQ(hyperparameters)
-    print(model)
+    print(model.hparams)

     # ------------------------------------------------------------
     # TRAINING
@@ -74,15 +74,27 @@ if __name__ == "__main__":
     vis = VisGMLVQ2D(data=train_ds)

     # Define trainer
-    trainer = pl.Trainer(
-        callbacks=[
-            vis,
-            stopping_criterion,
-            es,
-            PlotLambdaMatrixToTensorboard(),
-        ],
-        max_epochs=1000,
-    )
+    trainer = pl.Trainer(callbacks=[
+        vis,
+        stopping_criterion,
+        es,
+        PlotLambdaMatrixToTensorboard(),
+    ], )

     # Train
     trainer.fit(model, train_loader)
+
+    # Manual save
+    trainer.save_checkpoint("./y_arch.ckpt")
+
+    # Load saved model
+    new_model = GMLVQ.load_from_checkpoint(
+        checkpoint_path="./y_arch.ckpt",
+        strict=True,
+    )
+
+    print(new_model.hparams)
+
+
+if __name__ == "__main__":
+    main()
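The example now wraps its body in main(), checkpoints the trained model, and reloads it with strict=True. As a minimal sanity check (a sketch, assuming the example's model and new_model are in scope; the attribute chain components_layer.components is the one used by the supervised architecture further below), the prototypes of the original and the reloaded model can be compared:

import torch

original = model.components_layer.components.detach()
restored = new_model.components_layer.components.detach()
assert torch.allclose(original, restored)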

View File

@@ -36,4 +36,4 @@ from .unsupervised import (
 )
 from .vis import *

-__version__ = "1.0.0-a3"
+__version__ = "1.0.0-a4"

View File

@@ -3,13 +3,10 @@ Proto Y Architecture
 Network architecture for Component based Learning.
 """
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import (
-    Callable,
-    Dict,
-    Set,
-    Type,
-)
+from typing import Any, Callable

 import pytorch_lightning as pl
 import torch
@@ -23,13 +20,24 @@ class BaseYArchitecture(pl.LightningModule):
         ...

     # Fields
-    registered_metrics: Dict[Type[Metric], Metric] = {}
-    registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
+    registered_metrics: dict[type[Metric], Metric] = {}
+    registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}

     # Type Hints for Necessary Fields
     components_layer: torch.nn.Module

     def __init__(self, hparams) -> None:
+        if type(hparams) is dict:
+            self.save_hyperparameters(hparams)
+            # TODO: => Move into Component Child
+            del hparams["initialized_proto_shape"]
+            hparams = self.HyperParameters(**hparams)
+        else:
+            self.save_hyperparameters(
+                hparams.__dict__,
+                ignore=["component_initializer"],
+            )
         super().__init__()

     # Common Steps
@@ -165,7 +173,7 @@ class BaseYArchitecture(pl.LightningModule):
     def register_torchmetric(
         self,
         name: Callable,
-        metric: Type[Metric],
+        metric: type[Metric],
         **metric_kwargs,
     ):
         if metric not in self.registered_metrics:
@@ -210,3 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
     # Other Hooks
     def training_epoch_end(self, outs) -> None:
         self.update_metrics_epoch()
+
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        checkpoint["hyper_parameters"] = {
+            'hparams': checkpoint["hyper_parameters"]
+        }
+        return super().on_save_checkpoint(checkpoint)
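With this change the constructor accepts either the HyperParameters dataclass (fresh training) or a plain dict (what Lightning hands back when restoring), and on_save_checkpoint nests the stored hyperparameters under an 'hparams' key so that load_from_checkpoint effectively calls the constructor with a single hparams argument. A stripped-down toy module illustrating the same round trip (a sketch under those assumptions; ToyModule is not repository code):

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import pytorch_lightning as pl


class ToyModule(pl.LightningModule):

    @dataclass
    class HyperParameters:
        lr: float = 0.1

    def __init__(self, hparams) -> None:
        super().__init__()
        if isinstance(hparams, dict):
            # restore path: Lightning passes the previously saved dict back in
            self.save_hyperparameters(hparams)
            hparams = self.HyperParameters(**hparams)
        else:
            # fresh construction from the dataclass
            self.save_hyperparameters(hparams.__dict__)

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        # nest the flat dict so it maps onto the single `hparams` constructor argument
        checkpoint["hyper_parameters"] = {
            "hparams": dict(checkpoint["hyper_parameters"])
        }

With this layout, ToyModule.load_from_checkpoint(path) goes through the dict branch of __init__ when rebuilding the module.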

View File

@@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
 from prototorch.core.initializers import (
     AbstractComponentsInitializer,
     LabelsInitializer,
+    ZerosCompInitializer,
 )

 from prototorch.y import BaseYArchitecture
@@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
     # Steps
     # ----------------------------------------------------------------------------------------------------
     def init_components(self, hparams: HyperParameters):
-        self.components_layer = LabeledComponents(
-            distribution=hparams.distribution,
-            components_initializer=hparams.component_initializer,
-            labels_initializer=LabelsInitializer(),
-        )
+        if hparams.component_initializer is not None:
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=hparams.component_initializer,
+                labels_initializer=LabelsInitializer(),
+            )
+            proto_shape = self.components_layer.components.shape[1:]
+            self.hparams["initialized_proto_shape"] = proto_shape
+        else:
+            # when restoring a checkpointed model
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=ZerosCompInitializer(
+                    self.hparams["initialized_proto_shape"]),
+            )

     # Properties
     # ----------------------------------------------------------------------------------------------------
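init_components now distinguishes two situations: during normal training the configured component_initializer builds the prototypes and their shape is recorded under "initialized_proto_shape"; when restoring from a checkpoint the initializer is None (it was excluded from the saved hyperparameters), so zero-valued prototypes of the recorded shape are created and then overwritten by the checkpoint's state_dict. A usage sketch, assuming GMLVQ is imported as in the example and ./y_arch.ckpt was written by it:

restored = GMLVQ.load_from_checkpoint("./y_arch.ckpt", strict=True)
print(restored.hparams["initialized_proto_shape"])  # shape recorded at training time
print(restored.components_layer.components.shape)   # prototypes filled from the state_dict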

View File

@@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
         lr: float = 0.1
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
+        return self.hparams.optimizer(self.parameters(),
+                                      lr=self.hparams.lr)  # type: ignore


 class MultipleLearningRateMixin(BaseYArchitecture):
@@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
         lr: dict = field(default_factory=lambda: dict())
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
         optimizers = []
-        for name, lr in self.lr.items():
+        for name, lr in self.hparams.lr.items():
             if not hasattr(self, name):
                 raise ValueError(f"{name} is not a parameter of {self}")
             else:
                 model_part = getattr(self, name)
                 if isinstance(model_part, Parameter):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             [model_part],
                             lr=lr,  # type: ignore
                         ))
                 elif hasattr(model_part, "parameters"):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             model_part.parameters(),
                             lr=lr,  # type: ignore
                         ))
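Both mixins drop their __init__ overrides: the learning rate and optimizer class are read from self.hparams inside configure_optimizers, which also works when the module has been rebuilt from the plain hparams dict stored in a checkpoint. A toy illustration of the pattern (a sketch, not repository code; ToyLrModule and its single parameter are made up for the example):

import pytorch_lightning as pl
import torch


class ToyLrModule(pl.LightningModule):

    def __init__(self, hparams: dict) -> None:
        super().__init__()
        self.save_hyperparameters(hparams)
        self.weight = torch.nn.Parameter(torch.zeros(2))

    def configure_optimizers(self):
        # lr and optimizer class come straight from the stored hyperparameters
        return self.hparams.optimizer(self.parameters(), lr=self.hparams.lr)


optimizer = ToyLrModule({"lr": 0.05, "optimizer": torch.optim.SGD}).configure_optimizers()
print(optimizer)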

View File

@@ -1,5 +1,7 @@
 from .glvq import GLVQ
+from .gmlvq import GMLVQ

 __all__ = [
     "GLVQ",
+    "GMLVQ",
 ]

View File

@@ -55,7 +55,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
 setup(
     name=safe_name("prototorch_" + PLUGIN_NAME),
-    version="1.0.0-a3",
+    version="1.0.0-a4",
     description="Pre-packaged prototype-based "
     "machine learning models using ProtoTorch and PyTorch-Lightning.",
     long_description=long_description,