Compare commits
6 Commits
Author | SHA1 | Date
---|---|---
 | 46ec7b07d7 |
 | 07dab5a5ca |
 | ed83138e1f |
 | 1be7d7ec09 |
 | 60d2a1d2c9 |
 | be7d7f43bd |
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.0.0a3
+current_version = 1.0.0a5
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
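
The parse pattern above can be sanity-checked against the new version string; a minimal sketch in Python (the pattern is copied from the config, the test string is this release's version):

    import re

    # Pattern from the [bumpversion] section above.
    PARSE = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?"

    match = re.fullmatch(PARSE, "1.0.0a5")
    assert match is not None
    print(match.groupdict())  # {'major': '1', 'minor': '0', 'patch': '0', 'release': 'a5'}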
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 
 # The full version, including alpha/beta/rc tags
 #
-release = "1.0.0-a3"
+release = "1.0.0-a5"
 
 # -- General configuration ---------------------------------------------------
 
@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
 
 # ##############################################################################
 
-if __name__ == "__main__":
+def main():
     # ------------------------------------------------------------
     # DATA
     # ------------------------------------------------------------
 
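
Wrapping the script body in a main() function keeps the module importable without running training as a side effect; a minimal sketch of the pattern (body elided):

    def main():
        ...  # build data, model, and trainer; then fit

    if __name__ == "__main__":
        main()  # executes only when run as a script, not on import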
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     # Create Model
     model = GMLVQ(hyperparameters)
 
-    print(model)
+    print(model.hparams)
 
     # ------------------------------------------------------------
     # TRAINING
@@ -74,15 +74,27 @@ if __name__ == "__main__":
     vis = VisGMLVQ2D(data=train_ds)
 
     # Define trainer
-    trainer = pl.Trainer(callbacks=[
-        vis,
-        stopping_criterion,
-        es,
-        PlotLambdaMatrixToTensorboard(),
-    ], )
+    trainer = pl.Trainer(
+        callbacks=[
+            vis,
+            stopping_criterion,
+            es,
+            PlotLambdaMatrixToTensorboard(),
+        ],
+        max_epochs=1000,
+    )
 
     # Train
     trainer.fit(model, train_loader)
 
+    # Manual save
+    trainer.save_checkpoint("./y_arch.ckpt")
+
+    # Load saved model
+    new_model = GMLVQ.load_from_checkpoint(
+        checkpoint_path="./y_arch.ckpt",
+        strict=True,
+    )
+
+    print(new_model.hparams)
+
+
+if __name__ == "__main__":
+    main()
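
The new tail of the example is a full checkpoint round trip: manual save, then restore through load_from_checkpoint. A self-contained sketch of the same pattern with a toy LightningModule (all names here are illustrative, not from the repository):

    import pytorch_lightning as pl
    import torch
    from torch.utils.data import DataLoader, TensorDataset


    class ToyModule(pl.LightningModule):

        def __init__(self, lr: float = 0.01):
            super().__init__()
            self.save_hyperparameters()  # records lr in the checkpoint
            self.layer = torch.nn.Linear(2, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


    train_loader = DataLoader(TensorDataset(torch.randn(8, 2), torch.randn(8, 1)),
                              batch_size=4)
    trainer = pl.Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
    trainer.fit(ToyModule(), train_loader)

    trainer.save_checkpoint("./toy.ckpt")
    # strict=True (the default) fails loudly on mismatched state_dict keys.
    restored = ToyModule.load_from_checkpoint("./toy.ckpt", strict=True)
    print(restored.hparams)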
@@ -36,4 +36,4 @@ from .unsupervised import (
 )
 from .vis import *
 
-__version__ = "1.0.0-a3"
+__version__ = "1.0.0-a5"
@@ -3,13 +3,10 @@ Proto Y Architecture
 
 Network architecture for Component based Learning.
 """
-from dataclasses import dataclass
-from typing import (
-    Callable,
-    Dict,
-    Set,
-    Type,
-)
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass
+from typing import Any, Callable
 
 import pytorch_lightning as pl
 import torch
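
The import swap replaces the typing aliases (Dict, Set, Type) with builtin generics; `from __future__ import annotations` (PEP 563) defers annotation evaluation, so `dict[type[Metric], Metric]` parses even on Python versions older than 3.9. A minimal illustration:

    from __future__ import annotations  # annotations are stored as strings


    class Registry:
        # Not evaluated at class-creation time, so builtin generics
        # are safe here even on Python 3.7/3.8.
        entries: dict[type, set[str]] = {}


    print(Registry.__annotations__["entries"])  # dict[type, set[str]]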
@@ -23,13 +20,23 @@ class BaseYArchitecture(pl.LightningModule):
         ...
 
     # Fields
-    registered_metrics: Dict[Type[Metric], Metric] = {}
-    registered_metric_callbacks: Dict[Type[Metric], Set[Callable]] = {}
+    registered_metrics: dict[type[Metric], Metric] = {}
+    registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
 
     # Type Hints for Necessary Fields
     components_layer: torch.nn.Module
 
     def __init__(self, hparams) -> None:
+        if type(hparams) is dict:
+            self.save_hyperparameters(hparams)
+            # TODO: => Move into Component Child
+            del hparams["initialized_proto_shape"]
+            hparams = self.HyperParameters(**hparams)
+        else:
+            hparam_dict = asdict(hparams)
+            hparam_dict["component_initializer"] = None
+            self.save_hyperparameters(hparam_dict, )
+
         super().__init__()
 
     # Common Steps
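
`__init__` now accepts either a plain dict (what Lightning hands back when restoring a checkpoint) or the HyperParameters dataclass: a dict loses its runtime-only `initialized_proto_shape` entry and is re-wrapped into the dataclass, while a dataclass is flattened with `asdict` before being logged. A reduced sketch of just the branching, with a stand-in dataclass:

    from dataclasses import asdict, dataclass


    @dataclass
    class HyperParameters:
        lr: float = 0.1


    def normalize(hparams):
        if type(hparams) is dict:
            hparams = dict(hparams)  # copy before mutating
            hparams.pop("initialized_proto_shape", None)  # runtime-only entry
            return HyperParameters(**hparams)
        # Dataclass path: flatten to a plain dict for logging.
        return asdict(hparams)


    print(normalize({"lr": 0.5, "initialized_proto_shape": (4, 2)}))  # HyperParameters(lr=0.5)
    print(normalize(HyperParameters()))  # {'lr': 0.1}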
@@ -165,7 +172,7 @@ class BaseYArchitecture(pl.LightningModule):
     def register_torchmetric(
         self,
         name: Callable,
-        metric: Type[Metric],
+        metric: type[Metric],
         **metric_kwargs,
     ):
         if metric not in self.registered_metrics:
@@ -210,3 +217,9 @@ class BaseYArchitecture(pl.LightningModule):
     # Other Hooks
     def training_epoch_end(self, outs) -> None:
         self.update_metrics_epoch()
+
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        checkpoint["hyper_parameters"] = {
+            'hparams': checkpoint["hyper_parameters"]
+        }
+        return super().on_save_checkpoint(checkpoint)
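
The new on_save_checkpoint hook nests the stored hyperparameters one level deeper, so that at load time they arrive as a single `hparams` argument matching `__init__(self, hparams)`. In plain Python, the checkpoint dict is rewritten like this (values abbreviated):

    checkpoint = {"hyper_parameters": {"lr": 0.1, "distribution": "..."}}

    # Wrap the old mapping under the single key "hparams", mirroring the
    # __init__(self, hparams) signature that load_from_checkpoint will call.
    checkpoint["hyper_parameters"] = {"hparams": checkpoint["hyper_parameters"]}

    print(checkpoint)
    # {'hyper_parameters': {'hparams': {'lr': 0.1, 'distribution': '...'}}}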
@@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
 from prototorch.core.initializers import (
     AbstractComponentsInitializer,
     LabelsInitializer,
+    ZerosCompInitializer,
 )
 from prototorch.y import BaseYArchitecture
 
@@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
     # Steps
     # ----------------------------------------------------------------------------------------------------
     def init_components(self, hparams: HyperParameters):
-        self.components_layer = LabeledComponents(
-            distribution=hparams.distribution,
-            components_initializer=hparams.component_initializer,
-            labels_initializer=LabelsInitializer(),
-        )
+        if hparams.component_initializer is not None:
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=hparams.component_initializer,
+                labels_initializer=LabelsInitializer(),
+            )
+            proto_shape = self.components_layer.components.shape[1:]
+            self.hparams["initialized_proto_shape"] = proto_shape
+        else:
+            # when restoring a checkpointed model
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=ZerosCompInitializer(
+                    self.hparams["initialized_proto_shape"]),
+            )
 
     # Properties
     # ----------------------------------------------------------------------------------------------------
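
When `component_initializer` is absent from the restored hyperparameters, the else branch allocates zero-valued components in the shape recorded during the first run; the real prototype values are then overwritten when the checkpoint's state_dict is loaded. A toy version of the idea (the zeros factory stands in for ZerosCompInitializer):

    import torch


    def build_components(initializer, recorded_shape):
        if initializer is not None:
            return initializer()  # fresh training run
        # Restoring: zero placeholders; load_state_dict fills in real values.
        return torch.zeros(recorded_shape)


    print(build_components(None, (3, 2)))  # 3 prototypes of dimension 2, all zeros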
@@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
         lr: float = 0.1
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 
-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
+        return self.hparams.optimizer(self.parameters(),
+                                      lr=self.hparams.lr)  # type: ignore
 
 
 class MultipleLearningRateMixin(BaseYArchitecture):
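
Dropping the mixin's `__init__` means lr and optimizer are no longer mirrored onto the instance; configure_optimizers reads them straight from self.hparams, so values restored from a checkpoint are honored automatically. A self-contained sketch of that lookup (the dataclass is a stand-in for the mixin's HyperParameters):

    from __future__ import annotations

    from dataclasses import dataclass

    import torch


    @dataclass
    class HyperParameters:
        lr: float = 0.1
        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam


    hp = HyperParameters(lr=0.05, optimizer=torch.optim.SGD)
    model = torch.nn.Linear(2, 1)
    # Everything the optimizer needs travels inside the hyperparameters.
    print(hp.optimizer(model.parameters(), lr=hp.lr))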
@@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
         lr: dict = field(default_factory=lambda: dict())
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 
-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
         optimizers = []
-        for name, lr in self.lr.items():
+        for name, lr in self.hparams.lr.items():
             if not hasattr(self, name):
                 raise ValueError(f"{name} is not a parameter of {self}")
             else:
                 model_part = getattr(self, name)
                 if isinstance(model_part, Parameter):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             [model_part],
                             lr=lr,  # type: ignore
                         ))
                 elif hasattr(model_part, "parameters"):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             model_part.parameters(),
                             lr=lr,  # type: ignore
                         ))
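
Here lr is a mapping from attribute names on the model to per-part learning rates, and each named part gets its own optimizer. A toy rendition of the loop (attribute names are illustrative):

    import torch
    from torch.nn import Parameter

    # Toy model with one submodule and one bare Parameter.
    model = torch.nn.Module()
    model.components_layer = torch.nn.Linear(2, 2)
    model.omega = Parameter(torch.eye(2))

    lr = {"components_layer": 0.1, "omega": 0.5}
    optimizers = []
    for name, rate in lr.items():
        part = getattr(model, name)
        if isinstance(part, Parameter):
            optimizers.append(torch.optim.Adam([part], lr=rate))
        else:  # anything with .parameters()
            optimizers.append(torch.optim.Adam(part.parameters(), lr=rate))

    print(optimizers)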
@@ -1,5 +1,7 @@
 from .glvq import GLVQ
+from .gmlvq import GMLVQ
 
 __all__ = [
     "GLVQ",
+    "GMLVQ",
 ]
setup.py

@@ -55,7 +55,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
 
 setup(
     name=safe_name("prototorch_" + PLUGIN_NAME),
-    version="1.0.0-a3",
+    version="1.0.0-a5",
     description="Pre-packaged prototype-based "
     "machine learning models using ProtoTorch and PyTorch-Lightning.",
     long_description=long_description,