Compare commits
	
		
			11 Commits
		
	
	
		
			v1.0.0a6
			...
			feature/be
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					9bb2e20dce | ||
| 
						 | 
					6748951b63 | ||
| 
						 | 
					c547af728b | ||
| 
						 | 
					482044ec87 | ||
| 
						 | 
					45f01f39d4 | ||
| 
						 | 
					9ab864fbdf | ||
| 
						 | 
					365e0fb931 | ||
| 
						 | 
					ba50dfba50 | ||
| 
						 | 
					16ca409f07 | ||
| 
						 | 
					c3cad19853 | ||
| 
						 | 
					ec294bdd37 | 
@@ -1,5 +1,5 @@
 | 
			
		||||
[bumpversion]
 | 
			
		||||
current_version = 1.0.0a6
 | 
			
		||||
current_version = 1.0.0a8
 | 
			
		||||
commit = True
 | 
			
		||||
tag = True
 | 
			
		||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)((?P<release>[a-zA-Z0-9_.-]+))?
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										5
									
								
								.github/workflows/pythonapp.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/workflows/pythonapp.yml
									
									
									
									
										vendored
									
									
								
							@@ -21,7 +21,7 @@ jobs:
 | 
			
		||||
      run: |
 | 
			
		||||
        python -m pip install --upgrade pip
 | 
			
		||||
        pip install .[all]
 | 
			
		||||
    - uses: pre-commit/action@v2.0.3
 | 
			
		||||
    - uses: pre-commit/action@v3.0.0
 | 
			
		||||
  compatibility:
 | 
			
		||||
    needs: style
 | 
			
		||||
    strategy:
 | 
			
		||||
@@ -36,7 +36,8 @@ jobs:
 | 
			
		||||
          python-version: "3.8"
 | 
			
		||||
        - os: windows-latest
 | 
			
		||||
          python-version: "3.9"
 | 
			
		||||
 | 
			
		||||
        - os: windows-latest
 | 
			
		||||
          python-version: "3.11"
 | 
			
		||||
    runs-on: ${{ matrix.os }}
 | 
			
		||||
    steps:
 | 
			
		||||
    - uses: actions/checkout@v2
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@
 | 
			
		||||
 | 
			
		||||
repos:
 | 
			
		||||
- repo: https://github.com/pre-commit/pre-commit-hooks
 | 
			
		||||
  rev: v4.2.0
 | 
			
		||||
  rev: v4.3.0
 | 
			
		||||
  hooks:
 | 
			
		||||
  - id: trailing-whitespace
 | 
			
		||||
    exclude: (^\.bumpversion\.cfg$|cli_messages\.py)
 | 
			
		||||
@@ -14,7 +14,7 @@ repos:
 | 
			
		||||
  - id: check-case-conflict
 | 
			
		||||
 | 
			
		||||
- repo: https://github.com/myint/autoflake
 | 
			
		||||
  rev: v1.4
 | 
			
		||||
  rev: v1.7.7
 | 
			
		||||
  hooks:
 | 
			
		||||
  - id: autoflake
 | 
			
		||||
 | 
			
		||||
@@ -24,7 +24,7 @@ repos:
 | 
			
		||||
  - id: isort
 | 
			
		||||
 | 
			
		||||
- repo: https://github.com/pre-commit/mirrors-mypy
 | 
			
		||||
  rev: v0.950
 | 
			
		||||
  rev: v0.982
 | 
			
		||||
  hooks:
 | 
			
		||||
  - id: mypy
 | 
			
		||||
    files: prototorch
 | 
			
		||||
@@ -43,7 +43,7 @@ repos:
 | 
			
		||||
  - id: python-check-blanket-noqa
 | 
			
		||||
 | 
			
		||||
- repo: https://github.com/asottile/pyupgrade
 | 
			
		||||
  rev: v2.32.1
 | 
			
		||||
  rev: v3.1.0
 | 
			
		||||
  hooks:
 | 
			
		||||
  - id: pyupgrade
 | 
			
		||||
 | 
			
		||||
@@ -52,3 +52,8 @@ repos:
 | 
			
		||||
  hooks:
 | 
			
		||||
  - id: gitlint
 | 
			
		||||
    args: [--contrib=CT1, --ignore=B6, --msg-filename]
 | 
			
		||||
 | 
			
		||||
- repo: https://github.com/dosisod/refurb
 | 
			
		||||
  rev: v1.4.0
 | 
			
		||||
  hooks:
 | 
			
		||||
    - id: refurb
 | 
			
		||||
 
 | 
			
		||||
@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
 | 
			
		||||
 | 
			
		||||
# The full version, including alpha/beta/rc tags
 | 
			
		||||
#
 | 
			
		||||
release = "1.0.0-a6"
 | 
			
		||||
release = "1.0.0-a8"
 | 
			
		||||
 | 
			
		||||
# -- General configuration ---------------------------------------------------
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -97,6 +97,13 @@ def main():
 | 
			
		||||
        step=Steps.VALIDATION,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    accuracy = LogTorchmetricCallback(
 | 
			
		||||
        'validation_accuracy',
 | 
			
		||||
        torchmetrics.Accuracy,
 | 
			
		||||
        num_classes=3,
 | 
			
		||||
        step=Steps.VALIDATION,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    es = EarlyStopping(
 | 
			
		||||
        monitor=stopping_criterion.name,
 | 
			
		||||
        mode="max",
 | 
			
		||||
@@ -111,6 +118,7 @@ def main():
 | 
			
		||||
        callbacks=[
 | 
			
		||||
            vis,
 | 
			
		||||
            recall,
 | 
			
		||||
            accuracy,
 | 
			
		||||
            stopping_criterion,
 | 
			
		||||
            es,
 | 
			
		||||
            PlotLambdaMatrixToTensorboard(),
 | 
			
		||||
 
 | 
			
		||||
@@ -22,4 +22,4 @@ __all__ = [
 | 
			
		||||
    "GLVQLossMixin",
 | 
			
		||||
]
 | 
			
		||||
 | 
			
		||||
__version__ = "1.0.0-a6"
 | 
			
		||||
__version__ = "1.0.0-a8"
 | 
			
		||||
 
 | 
			
		||||
@@ -46,15 +46,15 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
    components_layer: torch.nn.Module
 | 
			
		||||
 | 
			
		||||
    def __init__(self, hparams) -> None:
 | 
			
		||||
        if type(hparams) is dict:
 | 
			
		||||
        if isinstance(hparams, dict):
 | 
			
		||||
            self.save_hyperparameters(hparams)
 | 
			
		||||
            # TODO: => Move into Component Child
 | 
			
		||||
            del hparams["initialized_proto_shape"]
 | 
			
		||||
            hparams = self.HyperParameters(**hparams)
 | 
			
		||||
        else:
 | 
			
		||||
            hparam_dict = asdict(hparams)
 | 
			
		||||
            hparam_dict["component_initializer"] = None
 | 
			
		||||
            self.save_hyperparameters(hparam_dict, )
 | 
			
		||||
            hparams_dict = asdict(hparams)
 | 
			
		||||
            hparams_dict["component_initializer"] = None
 | 
			
		||||
            self.save_hyperparameters(hparams_dict, )
 | 
			
		||||
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
@@ -72,6 +72,9 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
 | 
			
		||||
    # external API
 | 
			
		||||
    def get_competition(self, batch, components):
 | 
			
		||||
        '''
 | 
			
		||||
        Returns the output of the competition layer.
 | 
			
		||||
        '''
 | 
			
		||||
        latent_batch, latent_components = self.backbone(batch, components)
 | 
			
		||||
        # TODO: => Latent Hook
 | 
			
		||||
        comparison_tensor = self.comparison(latent_batch, latent_components)
 | 
			
		||||
@@ -79,6 +82,9 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        return comparison_tensor
 | 
			
		||||
 | 
			
		||||
    def forward(self, batch):
 | 
			
		||||
        '''
 | 
			
		||||
        Returns the prediction.
 | 
			
		||||
        '''
 | 
			
		||||
        if isinstance(batch, torch.Tensor):
 | 
			
		||||
            batch = (batch, None)
 | 
			
		||||
        # TODO: manage different datatypes?
 | 
			
		||||
@@ -95,6 +101,9 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        return self.forward(batch)
 | 
			
		||||
 | 
			
		||||
    def forward_comparison(self, batch):
 | 
			
		||||
        '''
 | 
			
		||||
        Returns the Output of the comparison layer.
 | 
			
		||||
        '''
 | 
			
		||||
        if isinstance(batch, torch.Tensor):
 | 
			
		||||
            batch = (batch, None)
 | 
			
		||||
        # TODO: manage different datatypes?
 | 
			
		||||
@@ -103,6 +112,9 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        return self.get_competition(batch, components)
 | 
			
		||||
 | 
			
		||||
    def loss_forward(self, batch):
 | 
			
		||||
        '''
 | 
			
		||||
        Returns the output of the loss layer.
 | 
			
		||||
        '''
 | 
			
		||||
        # TODO: manage different datatypes?
 | 
			
		||||
        components = self.components_layer()
 | 
			
		||||
        # TODO: => Component Hook
 | 
			
		||||
@@ -115,37 +127,31 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        """
 | 
			
		||||
        All initialization necessary for the components step.
 | 
			
		||||
        """
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def init_backbone(self, hparams: HyperParameters) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        All initialization necessary for the backbone step.
 | 
			
		||||
        """
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def init_comparison(self, hparams: HyperParameters) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        All initialization necessary for the comparison step.
 | 
			
		||||
        """
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def init_competition(self, hparams: HyperParameters) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        All initialization necessary for the competition step.
 | 
			
		||||
        """
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def init_loss(self, hparams: HyperParameters) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        All initialization necessary for the loss step.
 | 
			
		||||
        """
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    def init_inference(self, hparams: HyperParameters) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        All initialization necessary for the inference step.
 | 
			
		||||
        """
 | 
			
		||||
        ...
 | 
			
		||||
 | 
			
		||||
    # Empty Steps
 | 
			
		||||
    def components(self):
 | 
			
		||||
@@ -162,7 +168,8 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        The backbone step receives the data batch and the components.
 | 
			
		||||
        It can transform both by an arbitrary function.
 | 
			
		||||
 | 
			
		||||
        It returns the transformed batch and components, each of the same length as the original input.
 | 
			
		||||
        It returns the transformed batch and components,
 | 
			
		||||
        each of the same length as the original input.
 | 
			
		||||
        """
 | 
			
		||||
        return batch, components
 | 
			
		||||
 | 
			
		||||
@@ -211,6 +218,9 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        step: str = Steps.TRAINING,
 | 
			
		||||
        **metric_kwargs,
 | 
			
		||||
    ):
 | 
			
		||||
        '''
 | 
			
		||||
        Register a callback for evaluating a torchmetric.
 | 
			
		||||
        '''
 | 
			
		||||
        if step == Steps.PREDICT:
 | 
			
		||||
            raise ValueError("Prediction metrics are not supported.")
 | 
			
		||||
 | 
			
		||||
@@ -224,10 +234,10 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        # Prediction Metrics
 | 
			
		||||
        preds = self(batch)
 | 
			
		||||
 | 
			
		||||
        x, y = batch
 | 
			
		||||
        _, y = batch
 | 
			
		||||
        for metric in self.registered_metrics[step]:
 | 
			
		||||
            instance = self.registered_metrics[step][metric].to(self.device)
 | 
			
		||||
            instance(y, preds)
 | 
			
		||||
            instance(y, preds.reshape(y.shape))
 | 
			
		||||
 | 
			
		||||
    def update_metrics_epoch(self, step):
 | 
			
		||||
        for metric in self.registered_metrics[step]:
 | 
			
		||||
@@ -247,7 +257,7 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
 | 
			
		||||
        return self.loss_forward(batch)
 | 
			
		||||
 | 
			
		||||
    def training_epoch_end(self, outs) -> None:
 | 
			
		||||
    def training_epoch_end(self, outputs) -> None:
 | 
			
		||||
        self.update_metrics_epoch(Steps.TRAINING)
 | 
			
		||||
 | 
			
		||||
    # >>>> Validation
 | 
			
		||||
@@ -256,7 +266,7 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
 | 
			
		||||
        return self.loss_forward(batch)
 | 
			
		||||
 | 
			
		||||
    def validation_epoch_end(self, outs) -> None:
 | 
			
		||||
    def validation_epoch_end(self, outputs) -> None:
 | 
			
		||||
        self.update_metrics_epoch(Steps.VALIDATION)
 | 
			
		||||
 | 
			
		||||
    # >>>> Test
 | 
			
		||||
@@ -264,7 +274,7 @@ class BaseYArchitecture(pl.LightningModule):
 | 
			
		||||
        self.update_metrics_step(batch, Steps.TEST)
 | 
			
		||||
        return self.loss_forward(batch)
 | 
			
		||||
 | 
			
		||||
    def test_epoch_end(self, outs) -> None:
 | 
			
		||||
    def test_epoch_end(self, outputs) -> None:
 | 
			
		||||
        self.update_metrics_epoch(Steps.TEST)
 | 
			
		||||
 | 
			
		||||
    # >>>> Prediction
 | 
			
		||||
 
 | 
			
		||||
@@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
 | 
			
		||||
        comparison_args: Keyword arguments for the comparison function. Default: {}.
 | 
			
		||||
        """
 | 
			
		||||
        comparison_fn: Callable = euclidean_distance
 | 
			
		||||
        comparison_args: dict = field(default_factory=lambda: dict())
 | 
			
		||||
        comparison_args: dict = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
        comparison_parameters: dict = field(default_factory=lambda: dict())
 | 
			
		||||
        comparison_parameters: dict = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
    # Steps
 | 
			
		||||
    # ----------------------------------------------------------------------------------------------
 | 
			
		||||
@@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
 | 
			
		||||
            **hparams.comparison_args,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.comparison_kwargs: dict[str, Tensor] = dict()
 | 
			
		||||
        self.comparison_kwargs: dict[str, Tensor] = {}
 | 
			
		||||
 | 
			
		||||
    def comparison(self, batch, components):
 | 
			
		||||
        comp_tensor, _ = components
 | 
			
		||||
@@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
 | 
			
		||||
        latent_dim: int = 2
 | 
			
		||||
        omega_initializer: type[
 | 
			
		||||
            AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
 | 
			
		||||
        omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
 | 
			
		||||
        omega_initializer_kwargs: dict = field(default_factory=dict)
 | 
			
		||||
 | 
			
		||||
    # Steps
 | 
			
		||||
    # ----------------------------------------------------------------------------------------------
 | 
			
		||||
@@ -137,3 +137,12 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
 | 
			
		||||
        '''
 | 
			
		||||
        lam = self.lambda_matrix
 | 
			
		||||
        return lam.abs().sum(0)
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def parameter_omega(self):
 | 
			
		||||
        return self._omega
 | 
			
		||||
 | 
			
		||||
    @parameter_omega.setter
 | 
			
		||||
    def parameter_omega(self, new_omega):
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            self._omega.data.copy_(new_omega)
 | 
			
		||||
 
 | 
			
		||||
@@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
 | 
			
		||||
        lr: The learning rate. Default: 0.1.
 | 
			
		||||
        optimizer: The optimizer to use. Default: torch.optim.Adam.
 | 
			
		||||
        """
 | 
			
		||||
        lr: dict = field(default_factory=lambda: dict())
 | 
			
		||||
        lr: dict = field(default_factory=dict)
 | 
			
		||||
        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
 | 
			
		||||
 | 
			
		||||
    # Hooks
 | 
			
		||||
 
 | 
			
		||||
@@ -1,13 +1,15 @@
 | 
			
		||||
import logging
 | 
			
		||||
import warnings
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import Optional, Type
 | 
			
		||||
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
import numpy as np
 | 
			
		||||
import pytorch_lightning as pl
 | 
			
		||||
import torch
 | 
			
		||||
import torchmetrics
 | 
			
		||||
from matplotlib import pyplot as plt
 | 
			
		||||
from prototorch.models.architectures.base import BaseYArchitecture, Steps
 | 
			
		||||
from prototorch.models.architectures.comparison import OmegaComparisonMixin
 | 
			
		||||
from prototorch.models.library.gmlvq import GMLVQ
 | 
			
		||||
from prototorch.models.vis import Vis2DAbstract
 | 
			
		||||
from prototorch.utils.utils import mesh2d
 | 
			
		||||
@@ -36,12 +38,14 @@ class LogTorchmetricCallback(pl.Callback):
 | 
			
		||||
        name,
 | 
			
		||||
        metric: Type[torchmetrics.Metric],
 | 
			
		||||
        step: str = Steps.TRAINING,
 | 
			
		||||
        on_epoch=True,
 | 
			
		||||
        **metric_kwargs,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        self.name = name
 | 
			
		||||
        self.metric = metric
 | 
			
		||||
        self.metric_kwargs = metric_kwargs
 | 
			
		||||
        self.step = step
 | 
			
		||||
        self.on_epoch = on_epoch
 | 
			
		||||
 | 
			
		||||
    def setup(
 | 
			
		||||
        self,
 | 
			
		||||
@@ -57,7 +61,12 @@ class LogTorchmetricCallback(pl.Callback):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __call__(self, value, pl_module: BaseYArchitecture):
 | 
			
		||||
        pl_module.log(self.name, value)
 | 
			
		||||
        pl_module.log(
 | 
			
		||||
            self.name,
 | 
			
		||||
            value,
 | 
			
		||||
            on_epoch=self.on_epoch,
 | 
			
		||||
            on_step=(not self.on_epoch),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LogConfusionMatrix(LogTorchmetricCallback):
 | 
			
		||||
@@ -207,7 +216,7 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
 | 
			
		||||
        # add to tensorboard
 | 
			
		||||
        if isinstance(trainer.logger, TensorBoardLogger):
 | 
			
		||||
            trainer.logger.experiment.add_figure(
 | 
			
		||||
                f"lambda_matrix",
 | 
			
		||||
                "lambda_matrix",
 | 
			
		||||
                self.fig,
 | 
			
		||||
                trainer.global_step,
 | 
			
		||||
            )
 | 
			
		||||
@@ -215,3 +224,84 @@ class PlotLambdaMatrixToTensorboard(pl.Callback):
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                f"{self.__class__.__name__} is not compatible with {trainer.logger.__class__.__name__} as logger. Use TensorBoardLogger instead."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Profiles(Enum):
 | 
			
		||||
    '''
 | 
			
		||||
    Available Profiles
 | 
			
		||||
    '''
 | 
			
		||||
    RELEVANCE = 'relevance'
 | 
			
		||||
    INFLUENCE = 'influence'
 | 
			
		||||
 | 
			
		||||
    def __str__(self):
 | 
			
		||||
        return str(self.value)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PlotMatrixProfiles(pl.Callback):
 | 
			
		||||
 | 
			
		||||
    def __init__(self, profile=Profiles.INFLUENCE, cmap='seismic') -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        self.cmap = cmap
 | 
			
		||||
        self.profile = profile
 | 
			
		||||
 | 
			
		||||
    def on_train_start(self, trainer, pl_module: GMLVQ):
 | 
			
		||||
        '''
 | 
			
		||||
        Plot initial profile.
 | 
			
		||||
        '''
 | 
			
		||||
        self._plot_profile(trainer, pl_module)
 | 
			
		||||
 | 
			
		||||
    def on_train_epoch_end(self, trainer, pl_module: GMLVQ):
 | 
			
		||||
        '''
 | 
			
		||||
        Plot after every epoch.
 | 
			
		||||
        '''
 | 
			
		||||
        self._plot_profile(trainer, pl_module)
 | 
			
		||||
 | 
			
		||||
    def _plot_profile(self, trainer, pl_module: GMLVQ):
 | 
			
		||||
 | 
			
		||||
        fig, ax = plt.subplots(1, 1)
 | 
			
		||||
 | 
			
		||||
        # plot lambda matrix
 | 
			
		||||
        l_matrix = torch.abs(pl_module.lambda_matrix)
 | 
			
		||||
 | 
			
		||||
        if self.profile == Profiles.RELEVANCE:
 | 
			
		||||
            profile_value = l_matrix.diag()
 | 
			
		||||
        elif self.profile == Profiles.INFLUENCE:
 | 
			
		||||
            profile_value = l_matrix.sum(0)
 | 
			
		||||
 | 
			
		||||
        # plot lambda matrix
 | 
			
		||||
        ax.plot(profile_value.detach().numpy())
 | 
			
		||||
 | 
			
		||||
        # add title
 | 
			
		||||
        ax.set_title(f'{self.profile} profile')
 | 
			
		||||
 | 
			
		||||
        # add to tensorboard
 | 
			
		||||
        if isinstance(trainer.logger, TensorBoardLogger):
 | 
			
		||||
            trainer.logger.experiment.add_figure(
 | 
			
		||||
                f"{self.profile}_matrix",
 | 
			
		||||
                fig,
 | 
			
		||||
                trainer.global_step,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            class_name = self.__class__.__name__
 | 
			
		||||
            logger_name = trainer.logger.__class__.__name__
 | 
			
		||||
            warnings.warn(
 | 
			
		||||
                f"{class_name} is not compatible with {logger_name} as logger. Use TensorBoardLogger instead."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class OmegaTraceNormalization(pl.Callback):
 | 
			
		||||
    '''
 | 
			
		||||
    Trace normalization of the Omega Matrix.
 | 
			
		||||
    '''
 | 
			
		||||
    __epsilon = torch.finfo(torch.float32).eps
 | 
			
		||||
 | 
			
		||||
    def on_train_epoch_end(self, trainer: "pl.Trainer",
 | 
			
		||||
                           pl_module: OmegaComparisonMixin) -> None:
 | 
			
		||||
 | 
			
		||||
        omega = pl_module.parameter_omega
 | 
			
		||||
        denominator = torch.sqrt(torch.trace(omega.T @ omega))
 | 
			
		||||
        logging.debug(
 | 
			
		||||
            "Apply Omega Trace Normalization: demoninator=%f",
 | 
			
		||||
            denominator.item(),
 | 
			
		||||
        )
 | 
			
		||||
        pl_module.parameter_omega = omega / (denominator + self.__epsilon)
 | 
			
		||||
 
 | 
			
		||||
@@ -41,7 +41,7 @@ class GMLVQ(
 | 
			
		||||
        comparison_args: Keyword arguments for the comparison function. Override Default: {}.
 | 
			
		||||
        """
 | 
			
		||||
        comparison_fn: Callable = omega_distance
 | 
			
		||||
        comparison_args: dict = field(default_factory=lambda: dict())
 | 
			
		||||
        comparison_args: dict = field(default_factory=dict)
 | 
			
		||||
        optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
 | 
			
		||||
 | 
			
		||||
        lr: dict = field(default_factory=lambda: dict(
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										7
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										7
									
								
								setup.py
									
									
									
									
									
								
							@@ -10,6 +10,8 @@
 | 
			
		||||
 | 
			
		||||
ProtoTorch models Plugin Package
 | 
			
		||||
"""
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from pkg_resources import safe_name
 | 
			
		||||
from setuptools import find_namespace_packages, setup
 | 
			
		||||
 | 
			
		||||
@@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
 | 
			
		||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
 | 
			
		||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
 | 
			
		||||
 | 
			
		||||
with open("README.md", "r") as fh:
 | 
			
		||||
    long_description = fh.read()
 | 
			
		||||
long_description = Path("README.md").read_text(encoding='utf8')
 | 
			
		||||
 | 
			
		||||
INSTALL_REQUIRES = [
 | 
			
		||||
    "prototorch>=0.7.3",
 | 
			
		||||
@@ -55,7 +56,7 @@ ALL = CLI + DEV + DOCS + EXAMPLES + TESTS
 | 
			
		||||
 | 
			
		||||
setup(
 | 
			
		||||
    name=safe_name("prototorch_" + PLUGIN_NAME),
 | 
			
		||||
    version="1.0.0-a6",
 | 
			
		||||
    version="1.0.0-a8",
 | 
			
		||||
    description="Pre-packaged prototype-based "
 | 
			
		||||
    "machine learning models using ProtoTorch and PyTorch-Lightning.",
 | 
			
		||||
    long_description=long_description,
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										13
									
								
								tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								tests/test_models.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
			
		||||
"""prototorch.models test suite."""
 | 
			
		||||
 | 
			
		||||
import prototorch as pt
 | 
			
		||||
from prototorch.models.library import GLVQ
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_glvq_model_build():
 | 
			
		||||
    hparams = GLVQ.HyperParameters(
 | 
			
		||||
        distribution=dict(num_classes=2, per_class=1),
 | 
			
		||||
        component_initializer=pt.initializers.RNCI(2),
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    model = GLVQ(hparams=hparams)
 | 
			
		||||
		Reference in New Issue
	
	Block a user