test(githooks): Add githooks for automatic commit checks
This commit is contained in:
		
							
								
								
									
										54
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								.pre-commit-config.yaml
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,54 @@
 | 
				
			|||||||
 | 
					# See https://pre-commit.com for more information
 | 
				
			||||||
 | 
					# See https://pre-commit.com/hooks.html for more hooks
 | 
				
			||||||
 | 
					repos:
 | 
				
			||||||
 | 
					-   repo: https://github.com/pre-commit/pre-commit-hooks
 | 
				
			||||||
 | 
					    rev: v4.0.1
 | 
				
			||||||
 | 
					    hooks:
 | 
				
			||||||
 | 
					    -   id: trailing-whitespace
 | 
				
			||||||
 | 
					    -   id: end-of-file-fixer
 | 
				
			||||||
 | 
					    -   id: check-yaml
 | 
				
			||||||
 | 
					    -   id: check-added-large-files
 | 
				
			||||||
 | 
					    -   id: check-ast
 | 
				
			||||||
 | 
					    -   id: check-case-conflict
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- repo: https://github.com/myint/autoflake
 | 
				
			||||||
 | 
					  rev: v1.4
 | 
				
			||||||
 | 
					  hooks:
 | 
				
			||||||
 | 
					  -   id: autoflake
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- repo: http://github.com/PyCQA/isort
 | 
				
			||||||
 | 
					  rev: 5.8.0
 | 
				
			||||||
 | 
					  hooks:
 | 
				
			||||||
 | 
					  -   id: isort
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					-   repo: https://github.com/pre-commit/mirrors-mypy
 | 
				
			||||||
 | 
					    rev: 'v0.902'
 | 
				
			||||||
 | 
					    hooks:
 | 
				
			||||||
 | 
					    -   id: mypy
 | 
				
			||||||
 | 
					        files: prototorch
 | 
				
			||||||
 | 
					        additional_dependencies: [types-pkg_resources]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					-   repo: https://github.com/pre-commit/mirrors-yapf
 | 
				
			||||||
 | 
					    rev: 'v0.31.0'  # Use the sha / tag you want to point at
 | 
				
			||||||
 | 
					    hooks:
 | 
				
			||||||
 | 
					    -   id: yapf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					-   repo: https://github.com/pre-commit/pygrep-hooks
 | 
				
			||||||
 | 
					    rev: v1.9.0  # Use the ref you want to point at
 | 
				
			||||||
 | 
					    hooks:
 | 
				
			||||||
 | 
					    -   id: python-use-type-annotations
 | 
				
			||||||
 | 
					    -   id: python-no-log-warn
 | 
				
			||||||
 | 
					    -   id: python-check-blanket-noqa
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					-   repo: https://github.com/asottile/pyupgrade
 | 
				
			||||||
 | 
					    rev: v2.19.4
 | 
				
			||||||
 | 
					    hooks:
 | 
				
			||||||
 | 
					    -   id: pyupgrade
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					-   repo: https://github.com/jorisroovers/gitlint
 | 
				
			||||||
 | 
					    rev: "v0.15.1"
 | 
				
			||||||
 | 
					    hooks:
 | 
				
			||||||
 | 
					    -   id: gitlint
 | 
				
			||||||
 | 
					        args: [--contrib=CT1, --ignore=B6, --msg-filename]
 | 
				
			||||||
							
								
								
									
										13
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										13
									
								
								README.md
									
									
									
									
									
								
							@@ -18,6 +18,19 @@ pip install prototorch_models
 | 
				
			|||||||
of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
 | 
					of** [ProtoTorch](https://github.com/si-cim/prototorch). The plugin should then
 | 
				
			||||||
be available for use in your Python environment as `prototorch.models`.
 | 
					be available for use in your Python environment as `prototorch.models`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					## Contribution
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This repository contains definition for [git hooks](https://githooks.com).
 | 
				
			||||||
 | 
					[Pre-commit](https://pre-commit.com) gets installed as development dependency with prototorch.
 | 
				
			||||||
 | 
					Please install the hooks by running
 | 
				
			||||||
 | 
					```bash
 | 
				
			||||||
 | 
					pre-commit install
 | 
				
			||||||
 | 
					pre-commit install --hook-type commit-msg
 | 
				
			||||||
 | 
					```
 | 
				
			||||||
 | 
					before creating the first commit.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Available models
 | 
					## Available models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
### LVQ Family
 | 
					### LVQ Family
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -70,4 +70,3 @@
 | 
				
			|||||||
    year="2018",
 | 
					    year="2018",
 | 
				
			||||||
    publisher="Springer International Publishing",
 | 
					    publisher="Springer International Publishing",
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,12 +1,20 @@
 | 
				
			|||||||
"""GMLVQ example using the MNIST dataset."""
 | 
					"""GMLVQ example using the MNIST dataset."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.models import ImageGMLVQ
 | 
					import torch
 | 
				
			||||||
from prototorch.models.data import train_on_mnist
 | 
					 | 
				
			||||||
from pytorch_lightning.utilities.cli import LightningCLI
 | 
					from pytorch_lightning.utilities.cli import LightningCLI
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
class GMLVQMNIST(train_on_mnist(batch_size=64), ImageGMLVQ):
 | 
					from prototorch.models import ImageGMLVQ
 | 
				
			||||||
    """Model Definition."""
 | 
					from prototorch.models.abstract import PrototypeModel
 | 
				
			||||||
 | 
					from prototorch.models.data import MNISTDataModule
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cli = LightningCLI(GMLVQMNIST)
 | 
					class ExperimentClass(ImageGMLVQ):
 | 
				
			||||||
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(hparams,
 | 
				
			||||||
 | 
					                         optimizer=torch.optim.Adam,
 | 
				
			||||||
 | 
					                         prototype_initializer=pt.components.zeros(28 * 28),
 | 
				
			||||||
 | 
					                         **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cli = LightningCLI(ImageGMLVQ, MNISTDataModule)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,9 +1,11 @@
 | 
				
			|||||||
model:
 | 
					model:
 | 
				
			||||||
  hparams:
 | 
					  hparams:
 | 
				
			||||||
 | 
					    input_dim: 784
 | 
				
			||||||
 | 
					    latent_dim: 784
 | 
				
			||||||
    distribution:
 | 
					    distribution:
 | 
				
			||||||
      num_classes: 10
 | 
					      num_classes: 10
 | 
				
			||||||
      prototypes_per_class: 2
 | 
					      prototypes_per_class: 2
 | 
				
			||||||
    input_dim: 784
 | 
					 | 
				
			||||||
    latent_dim: 784
 | 
					 | 
				
			||||||
    proto_lr: 0.01
 | 
					    proto_lr: 0.01
 | 
				
			||||||
    bb_lr: 0.01
 | 
					    bb_lr: 0.01
 | 
				
			||||||
 | 
					data:
 | 
				
			||||||
 | 
					  batch_size: 32
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,10 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,12 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,10 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,10 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,12 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,12 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def hex_to_rgb(hex_values):
 | 
					def hex_to_rgb(hex_values):
 | 
				
			||||||
    for v in hex_values:
 | 
					    for v in hex_values:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,10 +2,11 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,10 +3,11 @@
 | 
				
			|||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import matplotlib.pyplot as plt
 | 
					import matplotlib.pyplot as plt
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def plot_matrix(matrix):
 | 
					def plot_matrix(matrix):
 | 
				
			||||||
    title = "Lambda matrix"
 | 
					    title = "Lambda matrix"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,13 +2,14 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from sklearn.datasets import load_iris
 | 
					from sklearn.datasets import load_iris
 | 
				
			||||||
from sklearn.preprocessing import StandardScaler
 | 
					from sklearn.preprocessing import StandardScaler
 | 
				
			||||||
from torch.optim.lr_scheduler import ExponentialLR
 | 
					from torch.optim.lr_scheduler import ExponentialLR
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,11 +2,12 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from torchvision.transforms import Lambda
 | 
					from torchvision.transforms import Lambda
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    # Command-line arguments
 | 
					    # Command-line arguments
 | 
				
			||||||
    parser = argparse.ArgumentParser()
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,10 +1,19 @@
 | 
				
			|||||||
import prototorch as pt
 | 
					"""Prototorch Data Modules
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					This allows to store the used dataset inside a Lightning Module.
 | 
				
			||||||
 | 
					Mainly used for PytorchLightningCLI configurations.
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					from typing import Any, Optional, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytorch_lightning as pl
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
from torch.utils.data import DataLoader, random_split
 | 
					from torch.utils.data import DataLoader, Dataset, random_split
 | 
				
			||||||
from torchvision import transforms
 | 
					from torchvision import transforms
 | 
				
			||||||
from torchvision.datasets import MNIST
 | 
					from torchvision.datasets import MNIST
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# MNIST
 | 
				
			||||||
class MNISTDataModule(pl.LightningDataModule):
 | 
					class MNISTDataModule(pl.LightningDataModule):
 | 
				
			||||||
    def __init__(self, batch_size=32):
 | 
					    def __init__(self, batch_size=32):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
@@ -49,15 +58,67 @@ class MNISTDataModule(pl.LightningDataModule):
 | 
				
			|||||||
        return mnist_test
 | 
					        return mnist_test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def train_on_mnist(batch_size=256) -> type:
 | 
					# def train_on_mnist(batch_size=256) -> type:
 | 
				
			||||||
    class DataClass(pl.LightningModule):
 | 
					#     class DataClass(pl.LightningModule):
 | 
				
			||||||
        datamodule = MNISTDataModule(batch_size=batch_size)
 | 
					#         datamodule = MNISTDataModule(batch_size=batch_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        def __init__(self, *args, **kwargs):
 | 
					#         def __init__(self, *args, **kwargs):
 | 
				
			||||||
            prototype_initializer = kwargs.pop(
 | 
					#             prototype_initializer = kwargs.pop(
 | 
				
			||||||
                "prototype_initializer", pt.components.Zeros((28, 28, 1)))
 | 
					#                 "prototype_initializer", pt.components.Zeros((28, 28, 1)))
 | 
				
			||||||
            super().__init__(*args,
 | 
					#             super().__init__(*args,
 | 
				
			||||||
                             prototype_initializer=prototype_initializer,
 | 
					#                              prototype_initializer=prototype_initializer,
 | 
				
			||||||
                             **kwargs)
 | 
					#                              **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return DataClass
 | 
					#     dc: Type[DataClass] = DataClass
 | 
				
			||||||
 | 
					#     return dc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# ABSTRACT
 | 
				
			||||||
 | 
					class GeneralDataModule(pl.LightningDataModule):
 | 
				
			||||||
 | 
					    def __init__(self, dataset: Dataset, batch_size: int = 32) -> None:
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.train_dataset = dataset
 | 
				
			||||||
 | 
					        self.batch_size = batch_size
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def train_dataloader(self) -> DataLoader:
 | 
				
			||||||
 | 
					        return DataLoader(self.train_dataset, batch_size=self.batch_size)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# def train_on_dataset(dataset: Dataset, batch_size: int = 256):
 | 
				
			||||||
 | 
					#     class DataClass(pl.LightningModule):
 | 
				
			||||||
 | 
					#         datamodule = GeneralDataModule(dataset, batch_size)
 | 
				
			||||||
 | 
					#         datashape = dataset[0][0].shape
 | 
				
			||||||
 | 
					#         example_input_array = torch.zeros_like(dataset[0][0]).unsqueeze(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#         def __init__(self, *args: Any, **kwargs: Any) -> None:
 | 
				
			||||||
 | 
					#             prototype_initializer = kwargs.pop(
 | 
				
			||||||
 | 
					#                 "prototype_initializer",
 | 
				
			||||||
 | 
					#                 pt.components.Zeros(self.datashape),
 | 
				
			||||||
 | 
					#             )
 | 
				
			||||||
 | 
					#             super().__init__(*args,
 | 
				
			||||||
 | 
					#                              prototype_initializer=prototype_initializer,
 | 
				
			||||||
 | 
					#                              **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     return DataClass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# if __name__ == "__main__":
 | 
				
			||||||
 | 
					#     from prototorch.models import GLVQ
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     demo_dataset = pt.datasets.Iris()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     TrainingClass: Type = train_on_dataset(demo_dataset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     class DemoGLVQ(TrainingClass, GLVQ):
 | 
				
			||||||
 | 
					#         """Model Definition."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     # Hyperparameters
 | 
				
			||||||
 | 
					#     hparams = dict(
 | 
				
			||||||
 | 
					#         distribution={
 | 
				
			||||||
 | 
					#             "num_classes": 3,
 | 
				
			||||||
 | 
					#             "prototypes_per_class": 4
 | 
				
			||||||
 | 
					#         },
 | 
				
			||||||
 | 
					#         lr=0.01,
 | 
				
			||||||
 | 
					#     )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#     initialized = DemoGLVQ(hparams)
 | 
				
			||||||
 | 
					#     print(initialized)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,11 +3,8 @@
 | 
				
			|||||||
import torch
 | 
					import torch
 | 
				
			||||||
from prototorch.functions.activations import get_activation
 | 
					from prototorch.functions.activations import get_activation
 | 
				
			||||||
from prototorch.functions.competitions import wtac
 | 
					from prototorch.functions.competitions import wtac
 | 
				
			||||||
from prototorch.functions.distances import (
 | 
					from prototorch.functions.distances import (lomega_distance, omega_distance,
 | 
				
			||||||
    lomega_distance,
 | 
					                                            squared_euclidean_distance)
 | 
				
			||||||
    omega_distance,
 | 
					 | 
				
			||||||
    squared_euclidean_distance,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
from prototorch.functions.helper import get_flat
 | 
					from prototorch.functions.helper import get_flat
 | 
				
			||||||
from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
					from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 | 
				
			||||||
from prototorch.modules import LambdaLayer, LossLayer
 | 
					from prototorch.modules import LambdaLayer, LossLayer
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										15
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								setup.py
									
									
									
									
									
								
							@@ -1,10 +1,12 @@
 | 
				
			|||||||
"""
 | 
					"""
 | 
				
			||||||
  _____           _     _______             _
 | 
					
 | 
				
			||||||
 |  __ \         | |   |__   __|           | |
 | 
					 ######
 | 
				
			||||||
 | |__) | __ ___ | |_ ___ | | ___  _ __ ___| |__
 | 
					 #     # #####   ####  #####  ####  #####  ####  #####   ####  #    #
 | 
				
			||||||
 |  ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
 | 
					 #     # #    # #    #   #   #    #   #   #    # #    # #    # #    #
 | 
				
			||||||
 | |   | | | (_) | || (_) | | (_) | | | (__| | | |
 | 
					 ######  #    # #    #   #   #    #   #   #    # #    # #      ######
 | 
				
			||||||
 |_|   |_|  \___/ \__\___/|_|\___/|_|  \___|_| |_|Plugin
 | 
					 #       #####  #    #   #   #    #   #   #    # #####  #      #    #
 | 
				
			||||||
 | 
					 #       #   #  #    #   #   #    #   #   #    # #   #  #    # #    #
 | 
				
			||||||
 | 
					 #       #    #  ####    #    ####    #    ####  #    #  ####  #    #Plugin
 | 
				
			||||||
 | 
					
 | 
				
			||||||
ProtoTorch models Plugin Package
 | 
					ProtoTorch models Plugin Package
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
@@ -29,6 +31,7 @@ CLI = [
 | 
				
			|||||||
]
 | 
					]
 | 
				
			||||||
DEV = [
 | 
					DEV = [
 | 
				
			||||||
    "bumpversion",
 | 
					    "bumpversion",
 | 
				
			||||||
 | 
					    "pre-commit",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
DOCS = [
 | 
					DOCS = [
 | 
				
			||||||
    "recommonmark",
 | 
					    "recommonmark",
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user