11 Commits

Author SHA1 Message Date
Alexander Engelsberger
9bb2e20dce build: bump version 1.0.0a7 → 1.0.0a8 2022-10-26 14:53:52 +02:00
Alexander Engelsberger
6748951b63 ci: temporarily remove 3.11 2022-10-26 13:31:52 +02:00
Alexander Engelsberger
c547af728b ci: add refurb to pre-commit config 2022-10-26 13:19:45 +02:00
Alexander Engelsberger
482044ec87 ci: update pre-commit configuration 2022-10-26 13:03:15 +02:00
Alexander Engelsberger
45f01f39d4 ci: add python 3.11 to ci 2022-10-26 12:58:05 +02:00
Alexander Engelsberger
9ab864fbdf chore: add simple test to fix github action 2022-10-26 12:57:45 +02:00
Alexander Engelsberger
365e0fb931 feat: add useful callbacks for GMLVQ
omega trace normalization and matrix profile visualization
2022-09-21 13:23:43 +02:00
Alexander Engelsberger
ba50dfba50 fix: accuracy as torchmetric fixed 2022-09-21 10:22:35 +02:00
Alexander Engelsberger
16ca409f07 feat: metric callback defaults on epoch 2022-08-26 10:58:33 +02:00
Alexander Engelsberger
c3cad19853 build: bump version 1.0.0a6 → 1.0.0a7 2022-08-19 12:17:32 +02:00
Alexander Engelsberger
ec294bdd37 feat: add omega parameter api 2022-08-19 12:15:11 +02:00
13 changed files with 174 additions and 37 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 1.0.0a6 current_version = 1.0.0a8
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?

View File

@@ -21,7 +21,7 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install .[all] pip install .[all]
- uses: pre-commit/action@v2.0.3 - uses: pre-commit/action@v3.0.0
compatibility: compatibility:
needs: style needs: style
strategy: strategy:
@@ -36,7 +36,8 @@ jobs:
python-version: "3.8" python-version: "3.8"
- os: windows-latest - os: windows-latest
python-version: "3.9" python-version: "3.9"
- os: windows-latest
python-version: "3.11"
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2

View File

@@ -3,7 +3,7 @@
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0 rev: v4.3.0
hooks: hooks:
- id: trailing-whitespace - id: trailing-whitespace
exclude: (^\.bumpversion\.cfg$|cli_messages\.py) exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
@@ -14,7 +14,7 @@ repos:
- id: check-case-conflict - id: check-case-conflict
- repo: https://github.com/myint/autoflake - repo: https://github.com/myint/autoflake
rev: v1.4 rev: v1.7.7
hooks: hooks:
- id: autoflake - id: autoflake
@@ -24,7 +24,7 @@ repos:
- id: isort - id: isort
- repo: https://github.com/pre-commit/mirrors-mypy - repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.950 rev: v0.982
hooks: hooks:
- id: mypy - id: mypy
files: prototorch files: prototorch
@@ -43,7 +43,7 @@ repos:
- id: python-check-blanket-noqa - id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade - repo: https://github.com/asottile/pyupgrade
rev: v2.32.1 rev: v3.1.0
hooks: hooks:
- id: pyupgrade - id: pyupgrade
@@ -52,3 +52,8 @@ repos:
hooks: hooks:
- id: gitlint - id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename] args: [--contrib=CT1, --ignore=B6, --msg-filename]
- repo: https://github.com/dosisod/refurb
rev: v1.4.0
hooks:
- id: refurb

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "1.0.0-a6" release = "1.0.0-a8"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -97,6 +97,13 @@ def main():
step=Steps.VALIDATION, step=Steps.VALIDATION,
) )
accuracy = LogTorchmetricCallback(
'validation_accuracy',
torchmetrics.Accuracy,
num_classes=3,
step=Steps.VALIDATION,
)
es = EarlyStopping( es = EarlyStopping(
monitor=stopping_criterion.name, monitor=stopping_criterion.name,
mode="max", mode="max",
@@ -111,6 +118,7 @@ def main():
callbacks=[ callbacks=[
vis, vis,
recall, recall,
accuracy,
stopping_criterion, stopping_criterion,
es, es,
PlotLambdaMatrixToTensorboard(), PlotLambdaMatrixToTensorboard(),

View File

@@ -22,4 +22,4 @@ __all__ = [
"GLVQLossMixin", "GLVQLossMixin",
] ]
__version__ = "1.0.0-a6" __version__ = "1.0.0-a8"

View File

@@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):
components_layer: torch.nn.Module components_layer: torch.nn.Module
def __init__(self, hparams) -> None: def __init__(self, hparams) -> None:
if type(hparams) is dict: if isinstance(hparams, dict):
self.save_hyperparameters(hparams) self.save_hyperparameters(hparams)
# TODO: => Move into Component Child # TODO: => Move into Component Child
del hparams["initialized_proto_shape"] del hparams["initialized_proto_shape"]
hparams = self.HyperParameters(**hparams) hparams = self.HyperParameters(**hparams)
else: else:
hparam_dict = asdict(hparams) hparams_dict = asdict(hparams)
hparam_dict["component_initializer"] = None hparams_dict["component_initializer"] = None
self.save_hyperparameters(hparam_dict, ) self.save_hyperparameters(hparams_dict, )
super().__init__() super().__init__()
@@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):
# external API # external API
def get_competition(self, batch, components): def get_competition(self, batch, components):
'''
Returns the output of the competition layer.
'''
latent_batch, latent_components = self.backbone(batch, components) latent_batch, latent_components = self.backbone(batch, components)
# TODO: => Latent Hook # TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components) comparison_tensor = self.comparison(latent_batch, latent_components)
@@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
return comparison_tensor return comparison_tensor
def forward(self, batch): def forward(self, batch):
'''
Returns the prediction.
'''
if isinstance(batch, torch.Tensor): if isinstance(batch, torch.Tensor):
batch = (batch, None) batch = (batch, None)
# TODO: manage different datatypes? # TODO: manage different datatypes?
@@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
return self.forward(batch) return self.forward(batch)
def forward_comparison(self, batch): def forward_comparison(self, batch):
'''
Returns the Output of the comparison layer.
'''
if isinstance(batch, torch.Tensor): if isinstance(batch, torch.Tensor):
batch = (batch, None) batch = (batch, None)
# TODO: manage different datatypes? # TODO: manage different datatypes?
@@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
return self.get_competition(batch, components) return self.get_competition(batch, components)
def loss_forward(self, batch): def loss_forward(self, batch):
'''
Returns the output of the loss layer.
'''
# TODO: manage different datatypes? # TODO: manage different datatypes?
components = self.components_layer() components = self.components_layer()
# TODO: => Component Hook # TODO: => Component Hook
@@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
""" """
All initialization necessary for the components step. All initialization necessary for the components step.
""" """
...
def init_backbone(self, hparams: HyperParameters) -> None: def init_backbone(self, hparams: HyperParameters) -> None:
""" """
All initialization necessary for the backbone step. All initialization necessary for the backbone step.
""" """
...
def init_comparison(self, hparams: HyperParameters) -> None: def init_comparison(self, hparams: HyperParameters) -> None:
""" """
All initialization necessary for the comparison step. All initialization necessary for the comparison step.
""" """
...
def init_competition(self, hparams: HyperParameters) -> None: def init_competition(self, hparams: HyperParameters) -> None:
""" """
All initialization necessary for the competition step. All initialization necessary for the competition step.
""" """
...
def init_loss(self, hparams: HyperParameters) -> None: def init_loss(self, hparams: HyperParameters) -> None:
""" """
All initialization necessary for the loss step. All initialization necessary for the loss step.
""" """
...
def init_inference(self, hparams: HyperParameters) -> None: def init_inference(self, hparams: HyperParameters) -> None:
""" """
All initialization necessary for the inference step. All initialization necessary for the inference step.
""" """
...
# Empty Steps # Empty Steps
def components(self): def components(self):
@@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
The backbone step receives the data batch and the components. The backbone step receives the data batch and the components.
It can transform both by an arbitrary function. It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input. It returns the transformed batch and components,
each of the same length as the original input.
""" """
return batch, components return batch, components
@@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
step: str = Steps.TRAINING, step: str = Steps.TRAINING,
**metric_kwargs, **metric_kwargs,
): ):
'''
Register a callback for evaluating a torchmetric.
'''
if step == Steps.PREDICT: if step == Steps.PREDICT:
raise ValueError("Prediction metrics are not supported.") raise ValueError("Prediction metrics are not supported.")
@@ -224,10 +234,10 @@ class BaseYArchitecture(pl.LightningModule):
# Prediction Metrics # Prediction Metrics
preds = self(batch) preds = self(batch)
x, y = batch _, y = batch
for metric in self.registered_metrics[step]: for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device) instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds) instance(y, preds.reshape(y.shape))
def update_metrics_epoch(self, step): def update_metrics_epoch(self, step):
for metric in self.registered_metrics[step]: for metric in self.registered_metrics[step]:
@@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss_forward(batch) return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None: def training_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.TRAINING) self.update_metrics_epoch(Steps.TRAINING)
# >>>> Validation # >>>> Validation
@@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss_forward(batch) return self.loss_forward(batch)
def validation_epoch_end(self, outs) -> None: def validation_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.VALIDATION) self.update_metrics_epoch(Steps.VALIDATION)
# >>>> Test # >>>> Test
@@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
self.update_metrics_step(batch, Steps.TEST) self.update_metrics_step(batch, Steps.TEST)
return self.loss_forward(batch) return self.loss_forward(batch)
def test_epoch_end(self, outs) -> None: def test_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.TEST) self.update_metrics_epoch(Steps.TEST)
# >>>> Prediction # >>>> Prediction

View File

@@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
comparison_args: Keyword arguments for the comparison function. Default: {}. comparison_args: Keyword arguments for the comparison function. Default: {}.
""" """
comparison_fn: Callable = euclidean_distance comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict()) comparison_args: dict = field(default_factory=dict)
comparison_parameters: dict = field(default_factory=lambda: dict()) comparison_parameters: dict = field(default_factory=dict)
# Steps # Steps
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
@@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
**hparams.comparison_args, **hparams.comparison_args,
) )
self.comparison_kwargs: dict[str, Tensor] = dict() self.comparison_kwargs: dict[str, Tensor] = {}
def comparison(self, batch, components): def comparison(self, batch, components):
comp_tensor, _ = components comp_tensor, _ = components
@@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
latent_dim: int = 2 latent_dim: int = 2
omega_initializer: type[ omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
omega_initializer_kwargs: dict = field(default_factory=lambda: dict()) omega_initializer_kwargs: dict = field(default_factory=dict)
# Steps # Steps
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
@@ -137,3 +137,12 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
''' '''
lam = self.lambda_matrix lam = self.lambda_matrix
return lam.abs().sum(0) return lam.abs().sum(0)
@property
def parameter_omega(self):
return self._omega
@parameter_omega.setter
def parameter_omega(self, new_omega):
with torch.no_grad():
self._omega.data.copy_(new_omega)

View File

@@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
lr: The learning rate. Default: 0.1. lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam. optimizer: The optimizer to use. Default: torch.optim.Adam.
""" """
lr: dict = field(default_factory=lambda: dict()) lr: dict = field(default_factory=dict)
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks # Hooks

View File

@@ -1,13 +1,15 @@
import logging import logging
import warnings import warnings
from enum import Enum
from typing import Optional, Type from typing import Optional, Type
import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torchmetrics import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.architectures.base import BaseYArchitecture, Steps from prototorch.models.architectures.base import BaseYArchitecture, Steps
from prototorch.models.architectures.comparison import OmegaComparisonMixin
from prototorch.models.library.gmlvq import GMLVQ from prototorch.models.library.gmlvq import GMLVQ
from prototorch.models.vis import Vis2DAbstract from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d from prototorch.utils.utils import mesh2d
@@ -36,12 +38,14 @@ class LogTorchmetricCallback(pl.Callback):
name, name,
metric: Type[torchmetrics.Metric], metric: Type[torchmetrics.Metric],
step: str = Steps.TRAINING, step: str = Steps.TRAINING,
on_epoch=True,
**metric_kwargs, **metric_kwargs,
) -> None: ) -> None:
self.name = name self.name = name
self.metric = metric self.metric = metric
self.metric_kwargs = metric_kwargs self.metric_kwargs = metric_kwargs
self.step = step self.step = step
self.on_epoch = on_epoch
def setup( def setup(
self, self,
@@ -57,7 +61,12 @@ class LogTorchmetricCallback(pl.Callback):
) )
def __call__(self, value, pl_module: BaseYArchitecture): def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value) pl_module.log(
self.name,
value,
on_epoch=self.on_epoch,
on_step=(not self.on_epoch),
)
class LogConfusionMatrix(LogTorchmetricCallback): class LogConfusionMatrix(LogTorchmetricCallback):
@@ -207,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
# add to tensorboard # add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger): if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure( trainer.logger.experiment.add_figure(
f"lambda_matrix", "lambda_matrix",
self.fig, self.fig,
trainer.global_step, trainer.global_step,
) )
@@ -215,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
warnings.warn( warnings.warn(
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead." f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
) )
class Profiles(Enum):
'''
Available Profiles
'''
RELEVANCE = 'relevance'
INFLUENCE = 'influence'
def __str__(self):
return str(self.value)
class PlotMatrixProfiles(pl.Callback):
def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
super().__init__()
self.cmap = cmap
self.profile = profile
def on_train_start(self, trainer, pl_module: GMLVQ):
'''
Plot initial profile.
'''
self._plot_profile(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
'''
Plot after every epoch.
'''
self._plot_profile(trainer, pl_module)
def _plot_profile(self, trainer, pl_module: GMLVQ):
fig, ax = plt.subplots(1, 1)
# plot lambda matrix
l_matrix = torch.abs(pl_module.lambda_matrix)
if self.profile == Profiles.RELEVANCE:
profile_value = l_matrix.diag()
elif self.profile == Profiles.INFLUENCE:
profile_value = l_matrix.sum(0)
# plot lambda matrix
ax.plot(profile_value.detach().numpy())
# add title
ax.set_title(f'{self.profile} profile')
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"{self.profile}_matrix",
fig,
trainer.global_step,
)
else:
class_name = self.__class__.__name__
logger_name = trainer.logger.__class__.__name__
warnings.warn(
f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
)
class OmegaTraceNormalization(pl.Callback):
'''
Trace normalization of the Omega Matrix.
'''
__epsilon = torch.finfo(torch.float32).eps
def on_train_epoch_end(self, trainer: "pl.Trainer",
pl_module: OmegaComparisonMixin) -> None:
omega = pl_module.parameter_omega
denominator = torch.sqrt(torch.trace(omega.T @ omega))
logging.debug(
"Apply Omega Trace Normalization: demoninator=%f",
denominator.item(),
)
pl_module.parameter_omega = omega / (denominator + self.__epsilon)

View File

@@ -41,7 +41,7 @@ class GMLVQ(
comparison_args: Keyword arguments for the comparison function. Override Default: {}. comparison_args: Keyword arguments for the comparison function. Override Default: {}.
""" """
comparison_fn: Callable = omega_distance comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict()) comparison_args: dict = field(default_factory=dict)
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict( lr: dict = field(default_factory=lambda: dict(

View File

@@ -10,6 +10,8 @@
ProtoTorch models Plugin Package ProtoTorch models Plugin Package
""" """
from pathlib import Path
from pkg_resources import safe_name from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup from setuptools import find_namespace_packages, setup
@@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models" PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git" DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh: long_description = Path("README.md").read_text(encoding='utf8')
long_description = fh.read()
INSTALL_REQUIRES = [ INSTALL_REQUIRES = [
"prototorch>=0.7.3", "prototorch>=0.7.3",
@@ -55,7 +56,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
name=safe_name("prototorch_" + PLUGIN_NAME), name=safe_name("prototorch_" + PLUGIN_NAME),
version="1.0.0-a6", version="1.0.0-a8",
description="Pre-packaged prototype-based " description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.", "machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description, long_description=long_description,

13
tests/test_models.py Normal file
View File

@@ -0,0 +1,13 @@
"""prototorch.models test suite."""
import prototorch as pt
from prototorch.models.library import GLVQ
def test_glvq_model_build():
hparams = GLVQ.HyperParameters(
distribution=dict(num_classes=2, per_class=1),
component_initializer=pt.initializers.RNCI(2),
)
model = GLVQ(hparams=hparams)