6 Commits

Author                  SHA1        Message                                                      Date
Alexander Engelsberger  ed83138e1f  build: bump version 1.0.0a3 → 1.0.0a4                        2022-06-12 11:52:06 +02:00
Alexander Engelsberger  1be7d7ec09  fix: don't save component initializer as hparam              2022-06-12 11:40:33 +02:00
Alexander Engelsberger  60d2a1d2c9  fix: don't save prototype initializer in y-arch checkpoint   2022-06-12 11:12:55 +02:00
Alexander Engelsberger  be7d7f43bd  fix: fix problems with y architecture and checkpoint         2022-06-12 10:36:15 +02:00
Alexander Engelsberger  fe729781fc  build: bump version 1.0.0a2 → 1.0.0a3                        2022-06-09 14:59:07 +02:00
Alexander Engelsberger  a7df7be1c8  feat: add confusion matrix callback                          2022-06-09 14:55:59 +02:00
10 changed files with 173 additions and 78 deletions

View File

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.0.0a2
+current_version = 1.0.0a4
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 # The full version, including alpha/beta/rc tags
 #
-release = "1.0.0-a2"
+release = "1.0.0-a4"
 
 # -- General configuration ---------------------------------------------------

View File

@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
 # ##############################################################################
 
-if __name__ == "__main__":
+def main():
 
     # ------------------------------------------------------------
     # DATA
     # ------------------------------------------------------------
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     # Create Model
     model = GMLVQ(hyperparameters)
 
-    print(model)
+    print(model.hparams)
 
     # ------------------------------------------------------------
     # TRAINING
@@ -74,15 +74,27 @@ if __name__ == "__main__":
     vis = VisGMLVQ2D(data=train_ds)
 
     # Define trainer
-    trainer = pl.Trainer(
-        callbacks=[
-            vis,
-            stopping_criterion,
-            es,
-            PlotLambdaMatrixToTensorboard(),
-        ],
-        max_epochs=1000,
-    )
+    trainer = pl.Trainer(callbacks=[
+        vis,
+        stopping_criterion,
+        es,
+        PlotLambdaMatrixToTensorboard(),
+    ], )
 
     # Train
     trainer.fit(model, train_loader)
+
+    # Manual save
+    trainer.save_checkpoint("./y_arch.ckpt")
+
+    # Load saved model
+    new_model = GMLVQ.load_from_checkpoint(
+        checkpoint_path="./y_arch.ckpt",
+        strict=True,
+    )
+
+    print(new_model.hparams)
+
+
+if __name__ == "__main__":
+    main()
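The save/load round trip appended to this example exercises exactly what the checkpoint fixes above enable: since the component initializer is no longer stored in the checkpoint (commits 1be7d7ec09 and 60d2a1d2c9), the model can be restored in a fresh process without it. A minimal sketch under that assumption, using the import path from prototorch.y.library seen further down:

    from prototorch.y.library.gmlvq import GMLVQ

    # No component initializer is needed on restore: init_components falls
    # back to zero prototypes of the recorded shape, which the checkpoint's
    # state dict then overwrites with the trained values.
    restored = GMLVQ.load_from_checkpoint("./y_arch.ckpt", strict=True)
    print(restored.hparams)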

View File

@@ -36,4 +36,4 @@ from .unsupervised import (
 )
 from .vis import *
 
-__version__ = "1.0.0-a2"
+__version__ = "1.0.0-a4"

View File

@@ -3,17 +3,14 @@ Proto Y Architecture
 Network architecture for Component based Learning.
 """
 
+from __future__ import annotations
+
 from dataclasses import dataclass
-from typing import (
-    Dict,
-    Set,
-    Type,
-)
+from typing import Any, Callable
 
 import pytorch_lightning as pl
 import torch
 from torchmetrics import Metric
-from torchmetrics.classification.accuracy import Accuracy
 
 
 class BaseYArchitecture(pl.LightningModule):
@@ -22,12 +19,25 @@ class BaseYArchitecture(pl.LightningModule):
     class HyperParameters:
         ...
 
-    registered_metrics: Dict[Type[Metric], Metric] = {}
-    registered_metric_names: Dict[Type[Metric], Set[str]] = {}
+    # Fields
+    registered_metrics: dict[type[Metric], Metric] = {}
+    registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}
 
+    # Type Hints for Necessary Fields
     components_layer: torch.nn.Module
 
     def __init__(self, hparams) -> None:
+        if type(hparams) is dict:
+            self.save_hyperparameters(hparams)
+            # TODO: => Move into Component Child
+            del hparams["initialized_proto_shape"]
+            hparams = self.HyperParameters(**hparams)
+        else:
+            self.save_hyperparameters(
+                hparams.__dict__,
+                ignore=["component_initializer"],
+            )
+
         super().__init__()
 
         # Common Steps
@@ -42,22 +52,6 @@ class BaseYArchitecture(pl.LightningModule):
         # Inference Steps
         self.init_inference(hparams)
 
-        # Initialize Model Metrics
-        self.init_model_metrics()
-
-    # internal API, called by models and callbacks
-    def register_torchmetric(
-        self,
-        name: str,
-        metric: Type[Metric],
-        **metric_kwargs,
-    ):
-        if metric not in self.registered_metrics:
-            self.registered_metrics[metric] = metric(**metric_kwargs)
-            self.registered_metric_names[metric] = {name}
-        else:
-            self.registered_metric_names[metric].add(name)
-
     # external API
     def get_competition(self, batch, components):
         latent_batch, latent_components = self.latent(batch, components)
@@ -99,7 +93,6 @@ class BaseYArchitecture(pl.LightningModule):
         return self.loss(comparison_tensor, batch, components)
 
     # Empty Initialization
-    # TODO: Type hints
     # TODO: Docs
     def init_components(self, hparams: HyperParameters) -> None:
         ...
@@ -119,9 +112,6 @@ class BaseYArchitecture(pl.LightningModule):
     def init_inference(self, hparams: HyperParameters) -> None:
         ...
 
-    def init_model_metrics(self) -> None:
-        self.register_torchmetric('accuracy', Accuracy)
-
     # Empty Steps
     # TODO: Type hints
     def components(self):
@@ -177,11 +167,26 @@ class BaseYArchitecture(pl.LightningModule):
         raise NotImplementedError(
             "The inference step has no reasonable default.")
 
-    def update_metrics_step(self, batch):
-        x, y = batch
+    # Y Architecture Hooks
 
+    # internal API, called by models and callbacks
+    def register_torchmetric(
+        self,
+        name: Callable,
+        metric: type[Metric],
+        **metric_kwargs,
+    ):
+        if metric not in self.registered_metrics:
+            self.registered_metrics[metric] = metric(**metric_kwargs)
+            self.registered_metric_callbacks[metric] = {name}
+        else:
+            self.registered_metric_callbacks[metric].add(name)
+
+    def update_metrics_step(self, batch):
         # Prediction Metrics
-        preds = self(x)
+        preds = self(batch)
+        x, y = batch
+
         for metric in self.registered_metrics:
             instance = self.registered_metrics[metric].to(self.device)
             instance(y, preds)
@@ -191,22 +196,31 @@ class BaseYArchitecture(pl.LightningModule):
             instance = self.registered_metrics[metric].to(self.device)
             value = instance.compute()
-            for name in self.registered_metric_names[metric]:
-                self.log(name, value)
+            for callback in self.registered_metric_callbacks[metric]:
+                callback(value, self)
             instance.reset()
 
     # Lightning Hooks
+
+    # Steps
     def training_step(self, batch, batch_idx, optimizer_idx=None):
-        self.update_metrics_step(batch)
+        self.update_metrics_step([torch.clone(el) for el in batch])
 
         return self.loss_forward(batch)
 
-    def training_epoch_end(self, outs) -> None:
-        self.update_metrics_epoch()
-
     def validation_step(self, batch, batch_idx):
         return self.loss_forward(batch)
 
     def test_step(self, batch, batch_idx):
         return self.loss_forward(batch)
+
+    # Other Hooks
+    def training_epoch_end(self, outs) -> None:
+        self.update_metrics_epoch()
+
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        checkpoint["hyper_parameters"] = {
+            'hparams': checkpoint["hyper_parameters"]
+        }
+
+        return super().on_save_checkpoint(checkpoint)
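The on_save_checkpoint hook at the end is the piece that makes the new dict branch of __init__ work: on load, Lightning passes checkpoint["hyper_parameters"] back into the constructor, so nesting the saved dict under a single 'hparams' key makes it arrive as the one hparams argument. A minimal sketch of that round trip, with an illustrative lr field standing in for a real HyperParameters subclass:

    from dataclasses import dataclass

    @dataclass
    class HyperParameters:
        lr: float = 0.1

    # save: nest the flat hyperparameter dict under 'hparams'
    checkpoint = {"hyper_parameters": {"lr": 0.1, "initialized_proto_shape": (4, )}}
    checkpoint["hyper_parameters"] = {"hparams": checkpoint["hyper_parameters"]}

    # load: the dict branch of __init__ drops the runtime-only shape entry
    # and rebuilds the dataclass the rest of the architecture expects
    hparams = dict(checkpoint["hyper_parameters"]["hparams"])
    del hparams["initialized_proto_shape"]
    print(HyperParameters(**hparams))  # HyperParameters(lr=0.1)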

View File

@@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
 from prototorch.core.initializers import (
     AbstractComponentsInitializer,
     LabelsInitializer,
+    ZerosCompInitializer,
 )
 
 from prototorch.y import BaseYArchitecture
@@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
     # Steps
     # ----------------------------------------------------------------------------------------------------
     def init_components(self, hparams: HyperParameters):
-        self.components_layer = LabeledComponents(
-            distribution=hparams.distribution,
-            components_initializer=hparams.component_initializer,
-            labels_initializer=LabelsInitializer(),
-        )
+        if hparams.component_initializer is not None:
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=hparams.component_initializer,
+                labels_initializer=LabelsInitializer(),
+            )
+            proto_shape = self.components_layer.components.shape[1:]
+            self.hparams["initialized_proto_shape"] = proto_shape
+        else:
+            # when restoring a checkpointed model
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=ZerosCompInitializer(
+                    self.hparams["initialized_proto_shape"]),
+            )
 
     # Properties
     # ----------------------------------------------------------------------------------------------------
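This branch separates first-time initialization from checkpoint restore: during training the prototype shape is recorded in self.hparams, and on restore component_initializer is None (save_hyperparameters ignores it), so zero-valued prototypes of the recorded shape are allocated and later overwritten by the loaded state dict. A plain-torch stand-in for the restore path, with illustrative shape and count:

    import torch

    initialized_proto_shape = (4, )  # recorded in hparams during the first run
    num_prototypes = 3               # would follow from hparams.distribution

    # what the ZerosCompInitializer placeholder effectively amounts to;
    # load_state_dict later replaces it with the trained prototypes
    placeholder = torch.zeros(num_prototypes, *initialized_proto_shape)
    print(placeholder.shape)  # torch.Size([3, 4])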

View File

@@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
         lr: float = 0.1
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 
-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
+        return self.hparams.optimizer(self.parameters(),
+                                      lr=self.hparams.lr)  # type: ignore
 
 
 class MultipleLearningRateMixin(BaseYArchitecture):
@@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
         lr: dict = field(default_factory=lambda: dict())
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 
-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
         optimizers = []
-        for name, lr in self.lr.items():
+        for name, lr in self.hparams.lr.items():
             if not hasattr(self, name):
                 raise ValueError(f"{name} is not a parameter of {self}")
             else:
                 model_part = getattr(self, name)
                 if isinstance(model_part, Parameter):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             [model_part],
                             lr=lr,  # type: ignore
                         ))
                 elif hasattr(model_part, "parameters"):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             model_part.parameters(),
                             lr=lr,  # type: ignore
                         ))

View File

@@ -13,8 +13,18 @@ from prototorch.y.library.gmlvq import GMLVQ
 from pytorch_lightning.loggers import TensorBoardLogger
 
 DIVERGING_COLOR_MAPS = [
-    'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
-    'Spectral', 'coolwarm', 'bwr', 'seismic'
+    'PiYG',
+    'PRGn',
+    'BrBG',
+    'PuOr',
+    'RdGy',
+    'RdBu',
+    'RdYlBu',
+    'RdYlGn',
+    'Spectral',
+    'coolwarm',
+    'bwr',
+    'seismic',
 ]
@@ -40,13 +50,72 @@ class LogTorchmetricCallback(pl.Callback):
     ) -> None:
         if self.on == "prediction":
             pl_module.register_torchmetric(
-                self.name,
+                self,
                 self.metric,
                 **self.metric_kwargs,
             )
         else:
             raise ValueError(f"{self.on} is no valid metric hook")
 
+    def __call__(self, value, pl_module: BaseYArchitecture):
+        pl_module.log(self.name, value)
+
+
+class LogConfusionMatrix(LogTorchmetricCallback):
+
+    def __init__(
+        self,
+        num_classes,
+        name="confusion",
+        on='prediction',
+        **kwargs,
+    ):
+        super().__init__(
+            name,
+            torchmetrics.ConfusionMatrix,
+            on=on,
+            num_classes=num_classes,
+            **kwargs,
+        )
+
+    def __call__(self, value, pl_module: BaseYArchitecture):
+        fig, ax = plt.subplots()
+        ax.imshow(value.detach().cpu().numpy())
+
+        # Show all ticks and label them with the respective list entries
+        # ax.set_xticks(np.arange(len(farmers)), labels=farmers)
+        # ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
+
+        # Rotate the tick labels and set their alignment.
+        plt.setp(
+            ax.get_xticklabels(),
+            rotation=45,
+            ha="right",
+            rotation_mode="anchor",
+        )
+
+        # Loop over data dimensions and create text annotations.
+        for i in range(len(value)):
+            for j in range(len(value)):
+                text = ax.text(
+                    j,
+                    i,
+                    value[i, j].item(),
+                    ha="center",
+                    va="center",
+                    color="w",
+                )
+
+        ax.set_title(self.name)
+        fig.tight_layout()
+
+        pl_module.logger.experiment.add_figure(
+            tag=self.name,
+            figure=fig,
+            close=True,
+            global_step=pl_module.global_step,
+        )
+
 
 class VisGLVQ2D(Vis2DAbstract):
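A short usage sketch for the confusion-matrix callback added in commit a7df7be1c8. The import path is an assumption based on the from .vis import * re-export in prototorch/y, and num_classes=3 is illustrative:

    import pytorch_lightning as pl
    from prototorch.y.vis import LogConfusionMatrix

    # Registers a torchmetrics.ConfusionMatrix on the model and renders the
    # computed matrix as a TensorBoard figure at the end of each metric epoch.
    confusion = LogConfusionMatrix(num_classes=3)
    trainer = pl.Trainer(callbacks=[confusion])
    # trainer.fit(model, train_loader)  # model/loader as in the GMLVQ example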

View File

@@ -1,5 +1,7 @@
 from .glvq import GLVQ
+from .gmlvq import GMLVQ
 
 __all__ = [
     "GLVQ",
+    "GMLVQ",
 ]

View File

@@ -55,7 +55,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
 setup(
     name=safe_name("prototorch_" + PLUGIN_NAME),
-    version="1.0.0-a2",
+    version="1.0.0-a4",
     description="Pre-packaged prototype-based "
     "machine learning models using ProtoTorch and PyTorch-Lightning.",
     long_description=long_description,