Compare commits

8 Commits

| SHA1 |
|---|
| 46ec7b07d7 |
| 07dab5a5ca |
| ed83138e1f |
| 1be7d7ec09 |
| 60d2a1d2c9 |
| be7d7f43bd |
| fe729781fc |
| a7df7be1c8 |
```diff
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.0.0a2
+current_version = 1.0.0a5
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
```
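A quick standalone check of the `parse` pattern above (not part of the commit): the optional `release` group is what carries the `a5` pre-release suffix.

```python
import re

# The parse pattern from the bumpversion config above; the optional
# <release> group captures the pre-release suffix.
PARSE = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?"

m = re.match(PARSE, "1.0.0a5")
print(m.group("major"), m.group("minor"), m.group("patch"), m.group("release"))
# -> 1 0 0 a5
```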
```diff
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"

 # The full version, including alpha/beta/rc tags
 #
-release = "1.0.0-a2"
+release = "1.0.0-a5"

 # -- General configuration ---------------------------------------------------

```
```diff
@@ -13,8 +13,8 @@ from torch.utils.data import DataLoader
 # ##############################################################################

-if __name__ == "__main__":
+
+def main():
     # ------------------------------------------------------------
     # DATA
     # ------------------------------------------------------------
```
```diff
@@ -51,7 +51,7 @@ if __name__ == "__main__":
     # Create Model
     model = GMLVQ(hyperparameters)

-    print(model)
+    print(model.hparams)

     # ------------------------------------------------------------
     # TRAINING
```
```diff
@@ -74,15 +74,27 @@ if __name__ == "__main__":
     vis = VisGMLVQ2D(data=train_ds)

     # Define trainer
-    trainer = pl.Trainer(
-        callbacks=[
-            vis,
-            stopping_criterion,
-            es,
-            PlotLambdaMatrixToTensorboard(),
-        ],
-        max_epochs=1000,
-    )
+    trainer = pl.Trainer(callbacks=[
+        vis,
+        stopping_criterion,
+        es,
+        PlotLambdaMatrixToTensorboard(),
+    ], )

     # Train
     trainer.fit(model, train_loader)
+
+    # Manual save
+    trainer.save_checkpoint("./y_arch.ckpt")
+
+    # Load saved model
+    new_model = GMLVQ.load_from_checkpoint(
+        checkpoint_path="./y_arch.ckpt",
+        strict=True,
+    )
+
+    print(new_model.hparams)
+
+
+if __name__ == "__main__":
+    main()
```
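The appended save/load round trip above is the motivation for the checkpoint changes in the later hunks. As a hedged sanity check one might add after `load_from_checkpoint` (not part of this commit), the restored prototypes can be compared against the saved model:

```python
import torch

# Hypothetical addition, not in the diff: prototypes should survive the
# save/load round trip unchanged. `components_layer` is the attribute
# declared on BaseYArchitecture (see the hunks below).
assert torch.allclose(
    model.components_layer.components,
    new_model.components_layer.components,
)
print("checkpoint round trip preserved the prototypes")
```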
```diff
@@ -36,4 +36,4 @@ from .unsupervised import (
 )
 from .vis import *

-__version__ = "1.0.0-a2"
+__version__ = "1.0.0-a5"
```
```diff
@@ -3,17 +3,14 @@ Proto Y Architecture

 Network architecture for Component based Learning.
 """
-from dataclasses import dataclass
-from typing import (
-    Dict,
-    Set,
-    Type,
-)
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass
+from typing import Any, Callable

 import pytorch_lightning as pl
 import torch
 from torchmetrics import Metric
-from torchmetrics.classification.accuracy import Accuracy


 class BaseYArchitecture(pl.LightningModule):
```
```diff
@@ -22,12 +19,24 @@ class BaseYArchitecture(pl.LightningModule):
     class HyperParameters:
         ...

-    registered_metrics: Dict[Type[Metric], Metric] = {}
-    registered_metric_names: Dict[Type[Metric], Set[str]] = {}
+    # Fields
+    registered_metrics: dict[type[Metric], Metric] = {}
+    registered_metric_callbacks: dict[type[Metric], set[Callable]] = {}

+    # Type Hints for Necessary Fields
     components_layer: torch.nn.Module

     def __init__(self, hparams) -> None:
+        if type(hparams) is dict:
+            self.save_hyperparameters(hparams)
+            # TODO: => Move into Component Child
+            del hparams["initialized_proto_shape"]
+            hparams = self.HyperParameters(**hparams)
+        else:
+            hparam_dict = asdict(hparams)
+            hparam_dict["component_initializer"] = None
+            self.save_hyperparameters(hparam_dict, )
+
         super().__init__()

         # Common Steps
```
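This `__init__` now accepts either the nested `HyperParameters` dataclass (fresh training) or a plain dict (what Lightning hands back when it restores a checkpoint) and normalizes between the two with `asdict`. A self-contained sketch of that normalization, outside Lightning, with simplified field names:

```python
from dataclasses import asdict, dataclass

@dataclass
class HyperParameters:
    lr: float = 0.1
    component_initializer: object = None

def normalize(hparams):
    # Restore path: a checkpoint stores a plain dict.
    if type(hparams) is dict:
        return HyperParameters(**hparams)
    # Fresh path: the dataclass can be flattened for storage and rebuilt.
    return HyperParameters(**asdict(hparams))

print(normalize({"lr": 0.01}))
print(normalize(HyperParameters(lr=0.5)))
```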
```diff
@@ -42,22 +51,6 @@ class BaseYArchitecture(pl.LightningModule):
         # Inference Steps
         self.init_inference(hparams)

-        # Initialize Model Metrics
-        self.init_model_metrics()
-
-    # internal API, called by models and callbacks
-    def register_torchmetric(
-        self,
-        name: str,
-        metric: Type[Metric],
-        **metric_kwargs,
-    ):
-        if metric not in self.registered_metrics:
-            self.registered_metrics[metric] = metric(**metric_kwargs)
-            self.registered_metric_names[metric] = {name}
-        else:
-            self.registered_metric_names[metric].add(name)
-
     # external API
     def get_competition(self, batch, components):
         latent_batch, latent_components = self.latent(batch, components)
```
```diff
@@ -99,7 +92,6 @@ class BaseYArchitecture(pl.LightningModule):
         return self.loss(comparison_tensor, batch, components)

     # Empty Initialization
-    # TODO: Type hints
     # TODO: Docs
     def init_components(self, hparams: HyperParameters) -> None:
         ...
```
```diff
@@ -119,9 +111,6 @@ class BaseYArchitecture(pl.LightningModule):
     def init_inference(self, hparams: HyperParameters) -> None:
         ...

-    def init_model_metrics(self) -> None:
-        self.register_torchmetric('accuracy', Accuracy)
-
     # Empty Steps
     # TODO: Type hints
     def components(self):
```
```diff
@@ -177,11 +166,26 @@ class BaseYArchitecture(pl.LightningModule):
         raise NotImplementedError(
             "The inference step has no reasonable default.")

-    def update_metrics_step(self, batch):
-        x, y = batch
+    # Y Architecture Hooks
+
+    # internal API, called by models and callbacks
+    def register_torchmetric(
+        self,
+        name: Callable,
+        metric: type[Metric],
+        **metric_kwargs,
+    ):
+        if metric not in self.registered_metrics:
+            self.registered_metrics[metric] = metric(**metric_kwargs)
+            self.registered_metric_callbacks[metric] = {name}
+        else:
+            self.registered_metric_callbacks[metric].add(name)

+    def update_metrics_step(self, batch):
         # Prediction Metrics
-        preds = self(x)
+        preds = self(batch)

+        x, y = batch
         for metric in self.registered_metrics:
             instance = self.registered_metrics[metric].to(self.device)
             instance(y, preds)
```
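`register_torchmetric` now keeps one shared instance per metric class and fans the computed value out to a set of callables rather than logging under plain names. A standalone sketch of that registry contract, assuming a torchmetrics version in which `Accuracy()` takes no required arguments (as the removed `init_model_metrics` implies):

```python
import torch
from torchmetrics.classification.accuracy import Accuracy

registered_metrics = {}
registered_metric_callbacks = {}

def register_torchmetric(callback, metric, **metric_kwargs):
    # One shared metric instance per class; many consumers per instance.
    if metric not in registered_metrics:
        registered_metrics[metric] = metric(**metric_kwargs)
        registered_metric_callbacks[metric] = {callback}
    else:
        registered_metric_callbacks[metric].add(callback)

def log_accuracy(value, module):
    print("accuracy:", value.item())

register_torchmetric(log_accuracy, Accuracy)

preds = torch.tensor([0, 1, 1])
target = torch.tensor([0, 1, 0])
for metric, instance in registered_metrics.items():
    instance(preds, target)                      # per-step update
    for cb in registered_metric_callbacks[metric]:
        cb(instance.compute(), None)             # epoch-end fan-out; pl_module goes here
    instance.reset()
```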
```diff
@@ -191,22 +195,31 @@
             instance = self.registered_metrics[metric].to(self.device)
             value = instance.compute()

-            for name in self.registered_metric_names[metric]:
-                self.log(name, value)
+            for callback in self.registered_metric_callbacks[metric]:
+                callback(value, self)

             instance.reset()

     # Lightning Hooks
+
+    # Steps
     def training_step(self, batch, batch_idx, optimizer_idx=None):
-        self.update_metrics_step(batch)
+        self.update_metrics_step([torch.clone(el) for el in batch])

         return self.loss_forward(batch)

-    def training_epoch_end(self, outs) -> None:
-        self.update_metrics_epoch()
-
     def validation_step(self, batch, batch_idx):
         return self.loss_forward(batch)

     def test_step(self, batch, batch_idx):
         return self.loss_forward(batch)
+
+    # Other Hooks
+    def training_epoch_end(self, outs) -> None:
+        self.update_metrics_epoch()
+
+    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+        checkpoint["hyper_parameters"] = {
+            'hparams': checkpoint["hyper_parameters"]
+        }
+        return super().on_save_checkpoint(checkpoint)
```
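The `on_save_checkpoint` override nests the stored hyperparameters one level deeper. The point, sketched below with stand-in names (not repo code): Lightning re-instantiates a restored module roughly as `cls(**checkpoint["hyper_parameters"])`, so wrapping the dict under a single `"hparams"` key makes it arrive as the one `hparams` argument that `__init__` expects:

```python
checkpoint = {"hyper_parameters": {"lr": 0.1, "distribution": {}}}

# What on_save_checkpoint does: nest the stored mapping one level deeper.
checkpoint["hyper_parameters"] = {"hparams": checkpoint["hyper_parameters"]}

class Demo:
    def __init__(self, hparams) -> None:  # single-argument signature
        self.hparams = hparams

# Roughly how Lightning restores a module (simplified stand-in):
model = Demo(**checkpoint["hyper_parameters"])
print(model.hparams)  # {'lr': 0.1, 'distribution': {}}
```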
```diff
@@ -4,6 +4,7 @@ from prototorch.core.components import LabeledComponents
 from prototorch.core.initializers import (
     AbstractComponentsInitializer,
     LabelsInitializer,
+    ZerosCompInitializer,
 )
 from prototorch.y import BaseYArchitecture

```
```diff
@@ -30,11 +31,21 @@ class SupervisedArchitecture(BaseYArchitecture):
     # Steps
     # ----------------------------------------------------------------------------------------------------
     def init_components(self, hparams: HyperParameters):
-        self.components_layer = LabeledComponents(
-            distribution=hparams.distribution,
-            components_initializer=hparams.component_initializer,
-            labels_initializer=LabelsInitializer(),
-        )
+        if hparams.component_initializer is not None:
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=hparams.component_initializer,
+                labels_initializer=LabelsInitializer(),
+            )
+            proto_shape = self.components_layer.components.shape[1:]
+            self.hparams["initialized_proto_shape"] = proto_shape
+        else:
+            # when restoring a checkpointed model
+            self.components_layer = LabeledComponents(
+                distribution=hparams.distribution,
+                components_initializer=ZerosCompInitializer(
+                    self.hparams["initialized_proto_shape"]),
+            )

     # Properties
     # ----------------------------------------------------------------------------------------------------
```
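`init_components` now records the prototype shape in `self.hparams` on first initialization; on restore, where `component_initializer` is `None`, it allocates zero placeholders of the remembered shape for `load_state_dict` to overwrite. A standalone sketch of that branch (helper names are illustrative, not repo code):

```python
import torch

def build_components(initializer, n_protos=4, remembered_shape=None):
    if initializer is not None:
        protos = initializer(n_protos)       # fresh training
        return protos, protos.shape[1:]      # remember the per-prototype shape
    # Restore path: zeros of the remembered shape, later overwritten
    # by load_state_dict with the checkpointed values.
    return torch.zeros(n_protos, *remembered_shape), remembered_shape

protos, shape = build_components(lambda n: torch.rand(n, 2))
placeholder, _ = build_components(None, remembered_shape=shape)
print(protos.shape, placeholder.shape)  # torch.Size([4, 2]) twice
```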
```diff
@@ -24,17 +24,11 @@ class SingleLearningRateMixin(BaseYArchitecture):
         lr: float = 0.1
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.lr)  # type: ignore
+        return self.hparams.optimizer(self.parameters(),
+                                      lr=self.hparams.lr)  # type: ignore


 class MultipleLearningRateMixin(BaseYArchitecture):
```
```diff
@@ -55,31 +49,24 @@ class MultipleLearningRateMixin(BaseYArchitecture):
         lr: dict = field(default_factory=lambda: dict())
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam

-    # Steps
-    # ----------------------------------------------------------------------------------------------------
-    def __init__(self, hparams: HyperParameters) -> None:
-        super().__init__(hparams)
-        self.lr = hparams.lr
-        self.optimizer = hparams.optimizer
-
     # Hooks
     # ----------------------------------------------------------------------------------------------------
     def configure_optimizers(self):
         optimizers = []
-        for name, lr in self.lr.items():
+        for name, lr in self.hparams.lr.items():
             if not hasattr(self, name):
                 raise ValueError(f"{name} is not a parameter of {self}")
             else:
                 model_part = getattr(self, name)
                 if isinstance(model_part, Parameter):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             [model_part],
                             lr=lr,  # type: ignore
                         ))
                 elif hasattr(model_part, "parameters"):
                     optimizers.append(
-                        self.optimizer(
+                        self.hparams.optimizer(
                             model_part.parameters(),
                             lr=lr,  # type: ignore
                         ))
```
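Both mixins drop their `__init__` copies of `lr` and `optimizer`: after `save_hyperparameters()`, Lightning exposes the stored values with attribute access on `self.hparams`, so `configure_optimizers` can read them directly. A minimal self-contained stand-in for that behaviour:

```python
import torch

class Hparams(dict):
    # Minimal stand-in for Lightning's AttributeDict: a dict with
    # attribute-style access, as self.hparams provides.
    __getattr__ = dict.__getitem__

hparams = Hparams(lr=0.1, optimizer=torch.optim.Adam)
params = [torch.nn.Parameter(torch.zeros(1))]

# What configure_optimizers now does, conceptually:
optimizer = hparams.optimizer(params, lr=hparams.lr)
print(type(optimizer).__name__)  # Adam
```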
```diff
@@ -13,8 +13,18 @@ from prototorch.y.library.gmlvq import GMLVQ
 from pytorch_lightning.loggers import TensorBoardLogger

 DIVERGING_COLOR_MAPS = [
-    'PiYG', 'PRGn', 'BrBG', 'PuOr', 'RdGy', 'RdBu', 'RdYlBu', 'RdYlGn',
-    'Spectral', 'coolwarm', 'bwr', 'seismic'
+    'PiYG',
+    'PRGn',
+    'BrBG',
+    'PuOr',
+    'RdGy',
+    'RdBu',
+    'RdYlBu',
+    'RdYlGn',
+    'Spectral',
+    'coolwarm',
+    'bwr',
+    'seismic',
 ]
```
```diff
@@ -40,13 +50,72 @@ class LogTorchmetricCallback(pl.Callback):
     ) -> None:
         if self.on == "prediction":
             pl_module.register_torchmetric(
-                self.name,
+                self,
                 self.metric,
                 **self.metric_kwargs,
             )
         else:
             raise ValueError(f"{self.on} is no valid metric hook")

+    def __call__(self, value, pl_module: BaseYArchitecture):
+        pl_module.log(self.name, value)
+
+
+class LogConfusionMatrix(LogTorchmetricCallback):
+
+    def __init__(
+        self,
+        num_classes,
+        name="confusion",
+        on='prediction',
+        **kwargs,
+    ):
+        super().__init__(
+            name,
+            torchmetrics.ConfusionMatrix,
+            on=on,
+            num_classes=num_classes,
+            **kwargs,
+        )
+
+    def __call__(self, value, pl_module: BaseYArchitecture):
+        fig, ax = plt.subplots()
+        ax.imshow(value.detach().cpu().numpy())
+
+        # Show all ticks and label them with the respective list entries
+        # ax.set_xticks(np.arange(len(farmers)), labels=farmers)
+        # ax.set_yticks(np.arange(len(vegetables)), labels=vegetables)
+
+        # Rotate the tick labels and set their alignment.
+        plt.setp(
+            ax.get_xticklabels(),
+            rotation=45,
+            ha="right",
+            rotation_mode="anchor",
+        )
+
+        # Loop over data dimensions and create text annotations.
+        for i in range(len(value)):
+            for j in range(len(value)):
+                text = ax.text(
+                    j,
+                    i,
+                    value[i, j].item(),
+                    ha="center",
+                    va="center",
+                    color="w",
+                )
+
+        ax.set_title(self.name)
+        fig.tight_layout()
+
+        pl_module.logger.experiment.add_figure(
+            tag=self.name,
+            figure=fig,
+            close=True,
+            global_step=pl_module.global_step,
+        )
+
+
 class VisGLVQ2D(Vis2DAbstract):
```
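The new `LogConfusionMatrix` callback registers itself as the metric consumer and renders the epoch value as a TensorBoard figure instead of a scalar. A hedged usage sketch; the import path is an assumption, since file names are not visible in this view:

```python
import pytorch_lightning as pl

# Module path assumed; LogConfusionMatrix is defined in the diff above.
from prototorch.y.callbacks import LogConfusionMatrix  # hypothetical path

trainer = pl.Trainer(
    callbacks=[LogConfusionMatrix(num_classes=4)],
    max_epochs=100,
)
```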
```diff
@@ -1,5 +1,7 @@
 from .glvq import GLVQ
+from .gmlvq import GMLVQ

 __all__ = [
     "GLVQ",
+    "GMLVQ",
 ]
```
setup.py

```diff
@@ -55,7 +55,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS

 setup(
     name=safe_name("prototorch_" + PLUGIN_NAME),
-    version="1.0.0-a2",
+    version="1.0.0-a5",
     description="Pre-packaged prototype-based "
     "machine learning models using ProtoTorch and PyTorch-Lightning.",
     long_description=long_description,
```