9 Commits

Author SHA1 Message Date
Alexander Engelsberger
bc20acd63b Bump version: 0.4.1 → 0.4.2 2021-05-11 16:08:37 +02:00
Jensun Ravichandran
7bb93f027a Support for unequal prototype distributions 2021-05-11 16:11:11 +02:00
Alexander Engelsberger
a864cf5d4d Bump version: 0.4.0 → 0.4.1 2021-05-11 13:37:54 +02:00
Alexander Engelsberger
2175f524e8 Update bug report issues template. 2021-05-11 13:35:38 +02:00
Alexander Engelsberger
c1c21e92df Add LVQ 1 and LVQ 2.1 loss functions. 2021-05-11 13:25:10 +02:00
Alexander Engelsberger
2b676ee06e Fix travis.yml. 2021-05-10 17:15:05 +02:00
Jensun Ravichandran
dda2f1d779 Clean-up CI setup 2021-05-10 16:37:43 +02:00
Alexander Engelsberger
3a8388e24f Version 0.4.0 2021-05-10 15:13:58 +02:00
Alexander Engelsberger
fc7d64aaea Use Github Default Issue Templates 2021-05-04 11:20:15 +02:00
13 changed files with 153 additions and 60 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.4.0 current_version = 0.4.2
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

31
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@@ -0,0 +1,31 @@
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Install Prototorch by running '...'
2. Run script '...'
3. See errors
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. Ubuntu 20.10]
- Prototorch Version: [e.g. v0.4.0]
- Python Version: [e.g. 3.9.5]
**Additional context**
Add any other context about the problem here.

View File

@@ -0,0 +1,20 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.

View File

@@ -23,10 +23,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install . pip install .[all]
- name: Install extras
run: |
pip install -r requirements.txt
- name: Lint with flake8 - name: Lint with flake8
run: | run: |
pip install flake8 pip install flake8

View File

@@ -5,10 +5,8 @@ python: 3.8
cache: cache:
directories: directories:
- "./tests/artifacts" - "./tests/artifacts"
# - "$HOME/.prototorch/datasets"
install: install:
- pip install . --progress-bar off - pip install .[all] --progress-bar off
- pip install -r requirements.txt
# Generate code coverage report # Generate code coverage report
script: script:

View File

@@ -23,7 +23,7 @@ author = "Jensun Ravichandran"
# The full version, including alpha/beta/rc tags # The full version, including alpha/beta/rc tags
# #
release = "0.4.0" release = "0.4.2"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------

View File

@@ -1,7 +1,7 @@
"""ProtoTorch package.""" """ProtoTorch package."""
# Core Setup # Core Setup
__version__ = "0.4.0" __version__ = "0.4.2"
__all_core__ = [ __all_core__ = [
"datasets", "datasets",

View File

@@ -4,8 +4,10 @@ import warnings
from typing import Tuple from typing import Tuple
import torch import torch
from prototorch.components.initializers import (ComponentsInitializer, from prototorch.components.initializers import (ClassAwareInitializer,
EqualLabelInitializer, ComponentsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer,
ZeroReasoningsInitializer) ZeroReasoningsInitializer)
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
@@ -30,12 +32,15 @@ class Components(torch.nn.Module):
else: else:
self._initialize_components(number_of_components, initializer) self._initialize_components(number_of_components, initializer)
def _initialize_components(self, number_of_components, initializer): def _precheck_initializer(self, initializer):
if not isinstance(initializer, ComponentsInitializer): if not isinstance(initializer, ComponentsInitializer):
emsg = f"`initializer` has to be some subtype of " \ emsg = f"`initializer` has to be some subtype of " \
f"{ComponentsInitializer}. " \ f"{ComponentsInitializer}. " \
f"You have provided: {initializer=} instead." f"You have provided: {initializer=} instead."
raise TypeError(emsg) raise TypeError(emsg)
def _initialize_components(self, number_of_components, initializer):
self._precheck_initializer(initializer)
self._components = Parameter( self._components = Parameter(
initializer.generate(number_of_components)) initializer.generate(number_of_components))
@@ -57,7 +62,7 @@ class LabeledComponents(Components):
Every Component has a label assigned. Every Component has a label assigned.
""" """
def __init__(self, def __init__(self,
labels=None, distribution=None,
initializer=None, initializer=None,
*, *,
initialized_components=None): initialized_components=None):
@@ -65,15 +70,27 @@ class LabeledComponents(Components):
super().__init__(initialized_components=initialized_components[0]) super().__init__(initialized_components=initialized_components[0])
self._labels = initialized_components[1] self._labels = initialized_components[1]
else: else:
self._initialize_labels(labels) self._initialize_labels(distribution)
super().__init__(number_of_components=len(self._labels), super().__init__(number_of_components=len(self._labels),
initializer=initializer) initializer=initializer)
def _initialize_labels(self, labels): def _initialize_components(self, number_of_components, initializer):
if type(labels) == tuple: if isinstance(initializer, ClassAwareInitializer):
num_classes, prototypes_per_class = labels self._precheck_initializer(initializer)
labels = EqualLabelInitializer(num_classes, prototypes_per_class) self._components = Parameter(
initializer.generate(number_of_components, self.distribution))
else:
super()._initialize_components(self, number_of_components,
initializer)
def _initialize_labels(self, distribution):
if type(distribution) == tuple:
num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
elif type(distribution) == list:
labels = UnequalLabelsInitializer(distribution)
self.distribution = labels.distribution
self._labels = labels.generate() self._labels = labels.generate()
@property @property

View File

@@ -1,6 +1,7 @@
"""ProtoTroch Initializers.""" """ProtoTroch Initializers."""
import warnings import warnings
from collections.abc import Iterable from collections.abc import Iterable
from itertools import chain
import torch import torch
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
@@ -91,6 +92,15 @@ class ClassAwareInitializer(ComponentsInitializer):
self.clabels = torch.unique(self.labels) self.clabels = torch.unique(self.labels)
self.num_classes = len(self.clabels) self.num_classes = len(self.clabels)
def _get_samples_from_initializer(self, length, dist):
if not dist:
per_class = length // self.num_classes
dist = self.num_classes * [per_class]
samples_list = [
init.generate(n) for init, n in zip(self.initializers, dist)
]
return torch.vstack(samples_list)
class StratifiedMeanInitializer(ClassAwareInitializer): class StratifiedMeanInitializer(ClassAwareInitializer):
def __init__(self, arg): def __init__(self, arg):
@@ -102,10 +112,9 @@ class StratifiedMeanInitializer(ClassAwareInitializer):
class_initializer = MeanInitializer(class_data) class_initializer = MeanInitializer(class_data)
self.initializers.append(class_initializer) self.initializers.append(class_initializer)
def generate(self, length): def generate(self, length, dist=[]):
per_class = length // self.num_classes samples = self._get_samples_from_initializer(length, dist)
samples_list = [init.generate(per_class) for init in self.initializers] return samples
return torch.vstack(samples_list)
class StratifiedSelectionInitializer(ClassAwareInitializer): class StratifiedSelectionInitializer(ClassAwareInitializer):
@@ -126,10 +135,8 @@ class StratifiedSelectionInitializer(ClassAwareInitializer):
mask = torch.bernoulli(n1) - torch.bernoulli(n2) mask = torch.bernoulli(n1) - torch.bernoulli(n2)
return x + (self.noise * mask) return x + (self.noise * mask)
def generate(self, length): def generate(self, length, dist=[]):
per_class = length // self.num_classes samples = self._get_samples_from_initializer(length, dist)
samples_list = [init.generate(per_class) for init in self.initializers]
samples = torch.vstack(samples_list)
if self.noise is not None: if self.noise is not None:
# samples = self.add_noise(samples) # samples = self.add_noise(samples)
samples = samples + self.noise samples = samples + self.noise
@@ -142,11 +149,29 @@ class LabelsInitializer:
raise NotImplementedError("Subclasses should implement this!") raise NotImplementedError("Subclasses should implement this!")
class EqualLabelInitializer(LabelsInitializer): class UnequalLabelsInitializer(LabelsInitializer):
def __init__(self, dist):
self.dist = dist
@property
def distribution(self):
return self.dist
def generate(self):
clabels = range(len(self.dist))
labels = list(chain(*[[i] * n for i, n in zip(clabels, self.dist)]))
return torch.tensor(labels)
class EqualLabelsInitializer(LabelsInitializer):
def __init__(self, classes, per_class): def __init__(self, classes, per_class):
self.classes = classes self.classes = classes
self.per_class = per_class self.per_class = per_class
@property
def distribution(self):
return self.classes * [self.per_class]
def generate(self): def generate(self):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten() return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()

View File

@@ -31,3 +31,26 @@ def glvq_loss(distances, target_labels, prototype_labels):
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = (dp - dm) / (dp + dm) mu = (dp - dm) / (dp + dm)
return mu return mu
def lvq1_loss(distances, target_labels, prototype_labels):
"""LVQ1 loss function with support for one-hot labels.
See Section 4 [Sado&Yamada]
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
"""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp
mu[dp > dm] = -dm[dp > dm]
return mu
def lvq21_loss(distances, target_labels, prototype_labels):
"""LVQ2.1 loss function with support for one-hot labels.
See Section 4 [Sado&Yamada]
https://papers.nips.cc/paper/1995/file/9c3b1830513cc3b8fc4b76635d32e692-Paper.pdf
"""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = dp - dm
return mu

View File

@@ -1,5 +0,0 @@
matplotlib==3.1.2
pytest==5.3.4
requests==2.22.0
codecov==2.0.22
tqdm==4.44.1

View File

@@ -21,27 +21,28 @@ INSTALL_REQUIRES = [
"torchvision>=0.5.0", "torchvision>=0.5.0",
"numpy>=1.9.1", "numpy>=1.9.1",
] ]
DATASETS = [
"requests",
"tqdm",
]
DEV = ["bumpversion"]
DOCS = [ DOCS = [
"recommonmark", "recommonmark",
"sphinx", "sphinx",
"sphinx_rtd_theme", "sphinx_rtd_theme",
"sphinxcontrib-katex", "sphinxcontrib-katex",
] ]
DATASETS = [
"requests",
"tqdm",
]
EXAMPLES = [ EXAMPLES = [
"sklearn", "sklearn",
"matplotlib", "matplotlib",
"torchinfo", "torchinfo",
] ]
TESTS = ["pytest"] TESTS = ["codecov", "pytest"]
ALL = DOCS + DATASETS + EXAMPLES + TESTS ALL = DATASETS + DEV + DOCS + EXAMPLES + TESTS
setup( setup(
name="prototorch", name="prototorch",
version="0.4.0", version="0.4.2",
description="Highly extensible, GPU-supported " description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox " "Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.", "built using PyTorch and its nn API.",
@@ -71,6 +72,7 @@ setup(
"Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries",

15
tox.ini
View File

@@ -1,15 +0,0 @@
# tox (https://tox.readthedocs.io/) is a tool for running tests
# in multiple virtualenvs. This configuration file will run the
# test suite on all supported python versions. To use it, "pip install tox"
# and then run "tox" from this directory.
[tox]
envlist = py36,py37,py38
[testenv]
deps =
pytest
coverage
commands =
pip install -e .
coverage run -m pytest