Compare commits

11 Commits

9bb2e20dce
6748951b63
c547af728b
482044ec87
45f01f39d4
9ab864fbdf
365e0fb931
ba50dfba50
16ca409f07
c3cad19853
ec294bdd37
.bumpversion.cfg

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.0.0a6
+current_version = 1.0.0a8
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
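The parse pattern is what lets bumpversion split a version string such as 1.0.0a8 into named parts. A quick standalone check of the pattern (a sketch using plain re, not repo code):

    import re

    PARSE = r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?"

    match = re.fullmatch(PARSE, "1.0.0a8")
    assert match is not None
    # the named groups feed bumpversion's serialization
    assert match.group("major", "minor", "patch", "release") == ("1", "0", "0", "a8")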
.github/workflows/pythonapp.yml
@@ -21,7 +21,7 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install .[all]
-    - uses: pre-commit/action@v2.0.3
+    - uses: pre-commit/action@v3.0.0
   compatibility:
     needs: style
     strategy:
@@ -36,7 +36,8 @@ jobs:
           python-version: "3.8"
         - os: windows-latest
           python-version: "3.9"
-
+        - os: windows-latest
+          python-version: "3.11"
     runs-on: ${{ matrix.os }}
     steps:
     - uses: actions/checkout@v2
.pre-commit-config.yaml

@@ -3,7 +3,7 @@
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.2.0
+    rev: v4.3.0
     hooks:
       - id: trailing-whitespace
         exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
@@ -14,7 +14,7 @@ repos:
       - id: check-case-conflict
 
   - repo: https://github.com/myint/autoflake
-    rev: v1.4
+    rev: v1.7.7
     hooks:
       - id: autoflake
 
@@ -24,7 +24,7 @@ repos:
       - id: isort
 
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v0.950
+    rev: v0.982
     hooks:
       - id: mypy
         files: prototorch
@@ -43,7 +43,7 @@ repos:
       - id: python-check-blanket-noqa
 
   - repo: https://github.com/asottile/pyupgrade
-    rev: v2.32.1
+    rev: v3.1.0
    hooks:
      - id: pyupgrade
 
@@ -52,3 +52,8 @@ repos:
     hooks:
       - id: gitlint
         args: [--contrib=CT1, --ignore=B6, --msg-filename]
+
+  - repo: https://github.com/dosisod/refurb
+    rev: v1.4.0
+    hooks:
+      - id: refurb
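Besides the routine rev bumps, refurb is a new hook. It suggests modernizations such as replacing an open()/read() pair with pathlib's read_text, which is the same cleanup that appears in setup.py further down this compare. A minimal illustration of the pattern (not repo code):

    from pathlib import Path

    # style that refurb flags
    with open("README.md") as fh:
        text = fh.read()

    # suggested replacement
    text = Path("README.md").read_text()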
conf.py (Sphinx documentation configuration; directory not shown)

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 
 # The full version, including alpha/beta/rc tags
 #
-release = "1.0.0-a6"
+release = "1.0.0-a8"
 
 # -- General configuration ---------------------------------------------------
(example training script; filename not shown)

@@ -97,6 +97,13 @@ def main():
         step=Steps.VALIDATION,
     )
 
+    accuracy = LogTorchmetricCallback(
+        'validation_accuracy',
+        torchmetrics.Accuracy,
+        num_classes=3,
+        step=Steps.VALIDATION,
+    )
+
     es = EarlyStopping(
         monitor=stopping_criterion.name,
         mode="max",
@@ -111,6 +118,7 @@ def main():
         callbacks=[
             vis,
             recall,
+            accuracy,
             stopping_criterion,
             es,
             PlotLambdaMatrixToTensorboard(),
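The new accuracy callback mirrors the existing recall one: it wraps a stateful torchmetrics object that is updated during validation. Roughly what gets computed (an illustrative sketch using the pre-0.11 torchmetrics signature seen in this diff, not repo code):

    import torch
    import torchmetrics

    metric = torchmetrics.Accuracy(num_classes=3)

    preds = torch.tensor([0, 2, 1, 2])
    target = torch.tensor([0, 1, 1, 2])
    print(metric(preds, target))  # tensor(0.7500) - 3 of 4 correct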
(package __init__.py carrying __version__; filename not shown)

@@ -22,4 +22,4 @@ __all__ = [
     "GLVQLossMixin",
 ]
 
-__version__ = "1.0.0-a6"
+__version__ = "1.0.0-a8"
prototorch/models/architectures/base.py

@@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):
     components_layer: torch.nn.Module
 
     def __init__(self, hparams) -> None:
-        if type(hparams) is dict:
+        if isinstance(hparams, dict):
             self.save_hyperparameters(hparams)
             # TODO: => Move into Component Child
             del hparams["initialized_proto_shape"]
             hparams = self.HyperParameters(**hparams)
         else:
-            hparam_dict = asdict(hparams)
-            hparam_dict["component_initializer"] = None
-            self.save_hyperparameters(hparam_dict, )
+            hparams_dict = asdict(hparams)
+            hparams_dict["component_initializer"] = None
+            self.save_hyperparameters(hparams_dict, )
 
         super().__init__()
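isinstance is the idiomatic check here because it also accepts subclasses of dict, which type(...) is dict silently rejects (a generic illustration, not repo code):

    class AttrDict(dict):  # hypothetical subclass
        pass

    hparams = AttrDict()
    print(type(hparams) is dict)      # False - the subclass slips through
    print(isinstance(hparams, dict))  # True  - handled like any dict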
@@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):
 
     # external API
     def get_competition(self, batch, components):
+        '''
+        Returns the output of the competition layer.
+        '''
         latent_batch, latent_components = self.backbone(batch, components)
         # TODO: => Latent Hook
         comparison_tensor = self.comparison(latent_batch, latent_components)
@@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
         return comparison_tensor
 
     def forward(self, batch):
+        '''
+        Returns the prediction.
+        '''
         if isinstance(batch, torch.Tensor):
             batch = (batch, None)
         # TODO: manage different datatypes?
@@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
         return self.forward(batch)
 
     def forward_comparison(self, batch):
+        '''
+        Returns the output of the comparison layer.
+        '''
         if isinstance(batch, torch.Tensor):
             batch = (batch, None)
         # TODO: manage different datatypes?
@@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
         return self.get_competition(batch, components)
 
     def loss_forward(self, batch):
+        '''
+        Returns the output of the loss layer.
+        '''
         # TODO: manage different datatypes?
         components = self.components_layer()
         # TODO: => Component Hook
@@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
         """
         All initialization necessary for the components step.
         """
         ...
 
     def init_backbone(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the backbone step.
         """
         ...
 
     def init_comparison(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the comparison step.
         """
         ...
 
     def init_competition(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the competition step.
         """
         ...
 
     def init_loss(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the loss step.
         """
         ...
 
     def init_inference(self, hparams: HyperParameters) -> None:
         """
         All initialization necessary for the inference step.
         """
         ...
 
     # Empty Steps
     def components(self):
@@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
         The backbone step receives the data batch and the components.
         It can transform both by an arbitrary function.
 
-        It returns the transformed batch and components, each of the same length as the original input.
+        It returns the transformed batch and components,
+        each of the same length as the original input.
         """
         return batch, components
@@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
         step: str = Steps.TRAINING,
         **metric_kwargs,
     ):
+        '''
+        Register a callback for evaluating a torchmetric.
+        '''
         if step == Steps.PREDICT:
             raise ValueError("Prediction metrics are not supported.")
@@ -224,10 +234,10 @@ class BaseYArchitecture(pl.LightningModule):
         # Prediction Metrics
         preds = self(batch)
 
-        x, y = batch
+        _, y = batch
         for metric in self.registered_metrics[step]:
             instance = self.registered_metrics[step][metric].to(self.device)
-            instance(y, preds)
+            instance(y, preds.reshape(y.shape))
 
     def update_metrics_epoch(self, step):
         for metric in self.registered_metrics[step]:
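Two small hardening changes here: the unused inputs are discarded (_, y = batch), and the predictions are reshaped to the target's shape before updating the metric, which guards against a stray singleton dimension. Illustrative shapes (not repo code):

    import torch

    preds = torch.tensor([[0], [1], [1]])  # e.g. shape (3, 1) from the model
    y = torch.tensor([0, 1, 0])            # shape (3,)

    print(preds.reshape(y.shape))          # tensor([0, 1, 1]) - now aligned with y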
@@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):
 
         return self.loss_forward(batch)
 
-    def training_epoch_end(self, outs) -> None:
+    def training_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.TRAINING)
 
     # >>>> Validation
@@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):
 
         return self.loss_forward(batch)
 
-    def validation_epoch_end(self, outs) -> None:
+    def validation_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.VALIDATION)
 
     # >>>> Test
@@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
         self.update_metrics_step(batch, Steps.TEST)
         return self.loss_forward(batch)
 
-    def test_epoch_end(self, outs) -> None:
+    def test_epoch_end(self, outputs) -> None:
         self.update_metrics_epoch(Steps.TEST)
 
     # >>>> Prediction
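The outs-to-outputs renames in the three *_epoch_end hooks are cosmetic; outputs is the parameter name used in PyTorch Lightning's documented hook signatures.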
prototorch/models/architectures/comparison.py

@@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
         comparison_args: Keyword arguments for the comparison function. Default: {}.
         """
         comparison_fn: Callable = euclidean_distance
-        comparison_args: dict = field(default_factory=lambda: dict())
+        comparison_args: dict = field(default_factory=dict)
 
-        comparison_parameters: dict = field(default_factory=lambda: dict())
+        comparison_parameters: dict = field(default_factory=dict)
 
     # Steps
     # ----------------------------------------------------------------------------------------------
@@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
             **hparams.comparison_args,
         )
 
-        self.comparison_kwargs: dict[str, Tensor] = dict()
+        self.comparison_kwargs: dict[str, Tensor] = {}
 
     def comparison(self, batch, components):
         comp_tensor, _ = components
@@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         latent_dim: int = 2
         omega_initializer: type[
             AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
-        omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
+        omega_initializer_kwargs: dict = field(default_factory=dict)
 
     # Steps
     # ----------------------------------------------------------------------------------------------
@@ -137,3 +137,12 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
         '''
         lam = self.lambda_matrix
         return lam.abs().sum(0)
+
+    @property
+    def parameter_omega(self):
+        return self._omega
+
+    @parameter_omega.setter
+    def parameter_omega(self, new_omega):
+        with torch.no_grad():
+            self._omega.data.copy_(new_omega)
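The new parameter_omega property pair gives callbacks a sanctioned way to read and overwrite the omega matrix in place without the update entering the autograd graph. The core pattern (a sketch; omega stands in for self._omega):

    import torch

    omega = torch.nn.Parameter(torch.eye(2))
    new_omega = omega / 2  # e.g. a rescaled replacement

    with torch.no_grad():
        omega.data.copy_(new_omega)  # in-place update, invisible to autograd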
(learning-rate mixin; filename not shown)

@@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
         lr: The learning rate. Default: 0.1.
         optimizer: The optimizer to use. Default: torch.optim.Adam.
         """
-        lr: dict = field(default_factory=lambda: dict())
+        lr: dict = field(default_factory=dict)
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 
     # Hooks
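field(default_factory=dict) replaces field(default_factory=lambda: dict()) throughout this compare. Both give every instance its own fresh dict (which is why a factory is required for mutable defaults at all); passing dict directly is simply shorter and avoids an extra lambda call. A generic sketch (not the repo's classes):

    from dataclasses import dataclass, field

    @dataclass
    class HyperParameters:  # hypothetical
        lr: dict = field(default_factory=dict)

    a, b = HyperParameters(), HyperParameters()
    a.lr["protos"] = 0.1
    print(b.lr)  # {} - no state shared between instances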
(callbacks module; filename not shown)

@@ -1,13 +1,15 @@
+import logging
 import warnings
+from enum import Enum
 from typing import Optional, Type
 
-import matplotlib.pyplot as plt
 import numpy as np
 import pytorch_lightning as pl
 import torch
 import torchmetrics
+from matplotlib import pyplot as plt
 from prototorch.models.architectures.base import BaseYArchitecture, Steps
+from prototorch.models.architectures.comparison import OmegaComparisonMixin
 from prototorch.models.library.gmlvq import GMLVQ
 from prototorch.models.vis import Vis2DAbstract
 from prototorch.utils.utils import mesh2d
@@ -36,12 +38,14 @@ class LogTorchmetricCallback(pl.Callback):
         name,
         metric: Type[torchmetrics.Metric],
         step: str = Steps.TRAINING,
+        on_epoch=True,
         **metric_kwargs,
     ) -> None:
         self.name = name
         self.metric = metric
         self.metric_kwargs = metric_kwargs
         self.step = step
+        self.on_epoch = on_epoch
 
     def setup(
         self,
@@ -57,7 +61,12 @@ class LogTorchmetricCallback(pl.Callback):
         )
 
     def __call__(self, value, pl_module: BaseYArchitecture):
-        pl_module.log(self.name, value)
+        pl_module.log(
+            self.name,
+            value,
+            on_epoch=self.on_epoch,
+            on_step=(not self.on_epoch),
+        )
 
 
 class LogConfusionMatrix(LogTorchmetricCallback):
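With on_epoch=True (the new default) Lightning accumulates the logged value over the epoch; passing on_step=(not self.on_epoch) makes the two modes mutually exclusive, so each metric is reported either per step or per epoch, never both.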
@@ -207,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
         # add to tensorboard
         if isinstance(trainer.logger, TensorBoardLogger):
             trainer.logger.experiment.add_figure(
-                f"lambda_matrix",
+                "lambda_matrix",
                 self.fig,
                 trainer.global_step,
             )
@@ -215,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
             warnings.warn(
                 f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
             )
+
+
+class Profiles(Enum):
+    '''
+    Available Profiles
+    '''
+    RELEVANCE = 'relevance'
+    INFLUENCE = 'influence'
+
+    def __str__(self):
+        return str(self.value)
+
+
+class PlotMatrixProfiles(pl.Callback):
+
+    def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
+        super().__init__()
+        self.cmap = cmap
+        self.profile = profile
+
+    def on_train_start(self, trainer, pl_module: GMLVQ):
+        '''
+        Plot initial profile.
+        '''
+        self._plot_profile(trainer, pl_module)
+
+    def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
+        '''
+        Plot after every epoch.
+        '''
+        self._plot_profile(trainer, pl_module)
+
+    def _plot_profile(self, trainer, pl_module: GMLVQ):
+
+        fig, ax = plt.subplots(1, 1)
+
+        # compute the absolute lambda matrix
+        l_matrix = torch.abs(pl_module.lambda_matrix)
+
+        if self.profile == Profiles.RELEVANCE:
+            profile_value = l_matrix.diag()
+        elif self.profile == Profiles.INFLUENCE:
+            profile_value = l_matrix.sum(0)
+
+        # plot the selected profile
+        ax.plot(profile_value.detach().numpy())
+
+        # add title
+        ax.set_title(f'{self.profile} profile')
+
+        # add to tensorboard
+        if isinstance(trainer.logger, TensorBoardLogger):
+            trainer.logger.experiment.add_figure(
+                f"{self.profile}_matrix",
+                fig,
+                trainer.global_step,
+            )
+        else:
+            class_name = self.__class__.__name__
+            logger_name = trainer.logger.__class__.__name__
+            warnings.warn(
+                f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
+            )
+
+
+class OmegaTraceNormalization(pl.Callback):
+    '''
+    Trace normalization of the Omega Matrix.
+    '''
+    __epsilon = torch.finfo(torch.float32).eps
+
+    def on_train_epoch_end(self, trainer: "pl.Trainer",
+                           pl_module: OmegaComparisonMixin) -> None:
+
+        omega = pl_module.parameter_omega
+        denominator = torch.sqrt(torch.trace(omega.T @ omega))
+        logging.debug(
+            "Apply Omega Trace Normalization: denominator=%f",
+            denominator.item(),
+        )
+        pl_module.parameter_omega = omega / (denominator + self.__epsilon)
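OmegaTraceNormalization rescales omega after every training epoch so that trace(omega^T omega) becomes 1 (up to the epsilon guard against division by zero). Numerically (a standalone sketch):

    import torch

    omega = torch.tensor([[3.0, 0.0],
                          [0.0, 4.0]])

    denominator = torch.sqrt(torch.trace(omega.T @ omega))  # sqrt(9 + 16) = 5
    normalized = omega / denominator

    print(torch.trace(normalized.T @ normalized))  # tensor(1.) - unit trace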
prototorch/models/library/gmlvq.py

@@ -41,7 +41,7 @@ class GMLVQ(
         comparison_args: Keyword arguments for the comparison function. Override Default: {}.
         """
         comparison_fn: Callable = omega_distance
-        comparison_args: dict = field(default_factory=lambda: dict())
+        comparison_args: dict = field(default_factory=dict)
         optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
 
         lr: dict = field(default_factory=lambda: dict(
setup.py
@@ -10,6 +10,8 @@
 
 ProtoTorch models Plugin Package
 """
+from pathlib import Path
+
 from pkg_resources import safe_name
 from setuptools import find_namespace_packages, setup
 
@@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
 PROJECT_URL = "https://github.com/si-cim/prototorch_models"
 DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
 
-with open("README.md", "r") as fh:
-    long_description = fh.read()
+long_description = Path("README.md").read_text(encoding='utf8')
 
 INSTALL_REQUIRES = [
     "prototorch>=0.7.3",
@@ -55,7 +56,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
 
 setup(
     name=safe_name("prototorch_" + PLUGIN_NAME),
-    version="1.0.0-a6",
+    version="1.0.0-a8",
     description="Pre-packaged prototype-based "
     "machine learning models using ProtoTorch and PyTorch-Lightning.",
     long_description=long_description,
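The open()-to-Path.read_text change is exactly the kind of modernization the newly added refurb hook suggests, and it makes the README's encoding explicit instead of relying on the platform default.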
tests/test_models.py (new file)
@@ -0,0 +1,13 @@
+"""prototorch.models test suite."""
+
+import prototorch as pt
+from prototorch.models.library import GLVQ
+
+
+def test_glvq_model_build():
+    hparams = GLVQ.HyperParameters(
+        distribution=dict(num_classes=2, per_class=1),
+        component_initializer=pt.initializers.RNCI(2),
+    )
+
+    model = GLVQ(hparams=hparams)
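The new test only exercises model construction; with a development install it runs via pytest tests/test_models.py.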