11 Commits

Author SHA1 Message Date
Alexander Engelsberger
9bb2e20dce build: bump version 1.0.0a7 → 1.0.0a8 2022-10-26 14:53:52 +02:00
Alexander Engelsberger
6748951b63 ci: temporarily remove 3.11 2022-10-26 13:31:52 +02:00
Alexander Engelsberger
c547af728b ci: add refurb to pre-commit config 2022-10-26 13:19:45 +02:00
Alexander Engelsberger
482044ec87 ci: update pre-commit configuration 2022-10-26 13:03:15 +02:00
Alexander Engelsberger
45f01f39d4 ci: add python 3.11 to ci 2022-10-26 12:58:05 +02:00
Alexander Engelsberger
9ab864fbdf chore: add simple test to fix github action 2022-10-26 12:57:45 +02:00
Alexander Engelsberger
365e0fb931 feat: add useful callbacks for GMLVQ
omega trace normalization and matrix profile visualization
2022-09-21 13:23:43 +02:00
Alexander Engelsberger
ba50dfba50 fix: accuracy as torchmetric fixed 2022-09-21 10:22:35 +02:00
Alexander Engelsberger
16ca409f07 feat: metric callback defaults on epoch 2022-08-26 10:58:33 +02:00
Alexander Engelsberger
c3cad19853 build: bump version 1.0.0a6 → 1.0.0a7 2022-08-19 12:17:32 +02:00
Alexander Engelsberger
ec294bdd37 feat: add omega parameter api 2022-08-19 12:15:11 +02:00
13 changed files with 174 additions and 37 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.0.0a6
current_version = 1.0.0a8
commit = True
tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?

View File

@@ -21,7 +21,7 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .[all]
- uses: pre-commit/action@v2.0.3
- uses: pre-commit/action@v3.0.0
compatibility:
needs: style
strategy:
@@ -36,7 +36,8 @@ jobs:
python-version: "3.8"
- os: windows-latest
python-version: "3.9"
- os: windows-latest
python-version: "3.11"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2

View File

@@ -3,7 +3,7 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v4.3.0
hooks:
- id: trailing-whitespace
exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
@@ -14,7 +14,7 @@ repos:
- id: check-case-conflict
- repo: https://github.com/myint/autoflake
rev: v1.4
rev: v1.7.7
hooks:
- id: autoflake
@@ -24,7 +24,7 @@ repos:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.950
rev: v0.982
hooks:
- id: mypy
files: prototorch
@@ -43,7 +43,7 @@ repos:
- id: python-check-blanket-noqa
- repo: https://github.com/asottile/pyupgrade
rev: v2.32.1
rev: v3.1.0
hooks:
- id: pyupgrade
@@ -52,3 +52,8 @@ repos:
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]
- repo: https://github.com/dosisod/refurb
rev: v1.4.0
hooks:
- id: refurb

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags
#
release = "1.0.0-a6"
release = "1.0.0-a8"
# -- General configuration ---------------------------------------------------

View File

@@ -97,6 +97,13 @@ def main():
step=Steps.VALIDATION,
)
accuracy = LogTorchmetricCallback(
'validation_accuracy',
torchmetrics.Accuracy,
num_classes=3,
step=Steps.VALIDATION,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
mode="max",
@@ -111,6 +118,7 @@ def main():
callbacks=[
vis,
recall,
accuracy,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),

View File

@@ -22,4 +22,4 @@ __all__ = [
"GLVQLossMixin",
]
__version__ = "1.0.0-a6"
__version__ = "1.0.0-a8"

View File

@@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):
components_layer: torch.nn.Module
def __init__(self, hparams) -> None:
if type(hparams) is dict:
if isinstance(hparams, dict):
self.save_hyperparameters(hparams)
# TODO: => Move into Component Child
del hparams["initialized_proto_shape"]
hparams = self.HyperParameters(**hparams)
else:
hparam_dict = asdict(hparams)
hparam_dict["component_initializer"] = None
self.save_hyperparameters(hparam_dict, )
hparams_dict = asdict(hparams)
hparams_dict["component_initializer"] = None
self.save_hyperparameters(hparams_dict, )
super().__init__()
@@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):
# external API
def get_competition(self, batch, components):
'''
Returns the output of the competition layer.
'''
latent_batch, latent_components = self.backbone(batch, components)
# TODO: => Latent Hook
comparison_tensor = self.comparison(latent_batch, latent_components)
@@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
return comparison_tensor
def forward(self, batch):
'''
Returns the prediction.
'''
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
@@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
return self.forward(batch)
def forward_comparison(self, batch):
'''
Returns the Output of the comparison layer.
'''
if isinstance(batch, torch.Tensor):
batch = (batch, None)
# TODO: manage different datatypes?
@@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
return self.get_competition(batch, components)
def loss_forward(self, batch):
'''
Returns the output of the loss layer.
'''
# TODO: manage different datatypes?
components = self.components_layer()
# TODO: => Component Hook
@@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
"""
All initialization necessary for the components step.
"""
...
def init_backbone(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the backbone step.
"""
...
def init_comparison(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the comparison step.
"""
...
def init_competition(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the competition step.
"""
...
def init_loss(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the loss step.
"""
...
def init_inference(self, hparams: HyperParameters) -> None:
"""
All initialization necessary for the inference step.
"""
...
# Empty Steps
def components(self):
@@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
The backbone step receives the data batch and the components.
It can transform both by an arbitrary function.
It returns the transformed batch and components, each of the same length as the original input.
It returns the transformed batch and components,
each of the same length as the original input.
"""
return batch, components
@@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
step: str = Steps.TRAINING,
**metric_kwargs,
):
'''
Register a callback for evaluating a torchmetric.
'''
if step == Steps.PREDICT:
raise ValueError("Prediction metrics are not supported.")
@@ -224,10 +234,10 @@ class BaseYArchitecture(pl.LightningModule):
# Prediction Metrics
preds = self(batch)
x, y = batch
_, y = batch
for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds)
instance(y, preds.reshape(y.shape))
def update_metrics_epoch(self, step):
for metric in self.registered_metrics[step]:
@@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss_forward(batch)
def training_epoch_end(self, outs) -> None:
def training_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.TRAINING)
# >>>> Validation
@@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):
return self.loss_forward(batch)
def validation_epoch_end(self, outs) -> None:
def validation_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.VALIDATION)
# >>>> Test
@@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
self.update_metrics_step(batch, Steps.TEST)
return self.loss_forward(batch)
def test_epoch_end(self, outs) -> None:
def test_epoch_end(self, outputs) -> None:
self.update_metrics_epoch(Steps.TEST)
# >>>> Prediction

View File

@@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_args: dict = field(default_factory=dict)
comparison_parameters: dict = field(default_factory=lambda: dict())
comparison_parameters: dict = field(default_factory=dict)
# Steps
# ----------------------------------------------------------------------------------------------
@@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
self.comparison_kwargs: dict[str, Tensor] = {}
def comparison(self, batch, components):
comp_tensor, _ = components
@@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
omega_initializer_kwargs: dict = field(default_factory=dict)
# Steps
# ----------------------------------------------------------------------------------------------
@@ -137,3 +137,12 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
'''
lam = self.lambda_matrix
return lam.abs().sum(0)
@property
def parameter_omega(self):
return self._omega
@parameter_omega.setter
def parameter_omega(self, new_omega):
with torch.no_grad():
self._omega.data.copy_(new_omega)

View File

@@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
lr: dict = field(default_factory=dict)
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks

View File

@@ -1,13 +1,15 @@
import logging
import warnings
from enum import Enum
from typing import Optional, Type
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torchmetrics
from matplotlib import pyplot as plt
from prototorch.models.architectures.base import BaseYArchitecture, Steps
from prototorch.models.architectures.comparison import OmegaComparisonMixin
from prototorch.models.library.gmlvq import GMLVQ
from prototorch.models.vis import Vis2DAbstract
from prototorch.utils.utils import mesh2d
@@ -36,12 +38,14 @@ class LogTorchmetricCallback(pl.Callback):
name,
metric: Type[torchmetrics.Metric],
step: str = Steps.TRAINING,
on_epoch=True,
**metric_kwargs,
) -> None:
self.name = name
self.metric = metric
self.metric_kwargs = metric_kwargs
self.step = step
self.on_epoch = on_epoch
def setup(
self,
@@ -57,7 +61,12 @@ class LogTorchmetricCallback(pl.Callback):
)
def __call__(self, value, pl_module: BaseYArchitecture):
pl_module.log(self.name, value)
pl_module.log(
self.name,
value,
on_epoch=self.on_epoch,
on_step=(not self.on_epoch),
)
class LogConfusionMatrix(LogTorchmetricCallback):
@@ -207,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"lambda_matrix",
"lambda_matrix",
self.fig,
trainer.global_step,
)
@@ -215,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
warnings.warn(
f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
)
class Profiles(Enum):
'''
Available Profiles
'''
RELEVANCE = 'relevance'
INFLUENCE = 'influence'
def __str__(self):
return str(self.value)
class PlotMatrixProfiles(pl.Callback):
def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
super().__init__()
self.cmap = cmap
self.profile = profile
def on_train_start(self, trainer, pl_module: GMLVQ):
'''
Plot initial profile.
'''
self._plot_profile(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
'''
Plot after every epoch.
'''
self._plot_profile(trainer, pl_module)
def _plot_profile(self, trainer, pl_module: GMLVQ):
fig, ax = plt.subplots(1, 1)
# plot lambda matrix
l_matrix = torch.abs(pl_module.lambda_matrix)
if self.profile == Profiles.RELEVANCE:
profile_value = l_matrix.diag()
elif self.profile == Profiles.INFLUENCE:
profile_value = l_matrix.sum(0)
# plot lambda matrix
ax.plot(profile_value.detach().numpy())
# add title
ax.set_title(f'{self.profile} profile')
# add to tensorboard
if isinstance(trainer.logger, TensorBoardLogger):
trainer.logger.experiment.add_figure(
f"{self.profile}_matrix",
fig,
trainer.global_step,
)
else:
class_name = self.__class__.__name__
logger_name = trainer.logger.__class__.__name__
warnings.warn(
f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
)
class OmegaTraceNormalization(pl.Callback):
'''
Trace normalization of the Omega Matrix.
'''
__epsilon = torch.finfo(torch.float32).eps
def on_train_epoch_end(self, trainer: "pl.Trainer",
pl_module: OmegaComparisonMixin) -> None:
omega = pl_module.parameter_omega
denominator = torch.sqrt(torch.trace(omega.T @ omega))
logging.debug(
"Apply Omega Trace Normalization: demoninator=%f",
denominator.item(),
)
pl_module.parameter_omega = omega / (denominator + self.__epsilon)

View File

@@ -41,7 +41,7 @@ class GMLVQ(
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
"""
comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_args: dict = field(default_factory=dict)
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict(

View File

@@ -10,6 +10,8 @@
ProtoTorch models Plugin Package
"""
from pathlib import Path
from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup
@@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
long_description = Path("README.md").read_text(encoding='utf8')
INSTALL_REQUIRES = [
"prototorch>=0.7.3",
@@ -55,7 +56,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
setup(
name=safe_name("prototorch_" + PLUGIN_NAME),
version="1.0.0-a6",
version="1.0.0-a8",
description="Pre-packaged prototype-based "
"machine learning models using ProtoTorch and PyTorch-Lightning.",
long_description=long_description,

13
tests/test_models.py Normal file
View File

@@ -0,0 +1,13 @@
"""prototorch.models test suite."""
import prototorch as pt
from prototorch.models.library import GLVQ
def test_glvq_model_build():
hparams = GLVQ.HyperParameters(
distribution=dict(num_classes=2, per_class=1),
component_initializer=pt.initializers.RNCI(2),
)
model = GLVQ(hparams=hparams)