Compare commits
v0.1.0-dev...v0.1.1-rc0 (54 commits)
Commits in this range (SHA1):
0cfbc0473b, cf0659d881, d17b9a3346, 532f63b1de, c11a3860df, dab91e471a,
a167565857, e063625486, 89eb5358a0, 5c59515128, 7eb7a6b194, 5811c4b9f9,
7b1887d56e, 63a25e7a38, a0f20a40f6, 88cbe0a126, a3548e0ddd, 3cfbc49254,
2b82830590, 553b1e1a65, a9d2855323, cf7d7b5d9d, a22c752342, 4158586cb9,
f80d9648c3, e54bf07030, 8c629c0cb1, 8f3a43f62a, 955661af95, c54d14c55e,
6090aad176, 1ec7bd261b, da3b0cc262, f640a22cf2, c843ace63d, 242c9de3b6,
438a5b9360, f98f3d095e, 21b0279839, b19cbcb76a, 7d5ab81dbf, bde408a80e,
900955d67a, 3757c937b3, 38f637aaeb, 6ddfe48a95, bf0e694321, e2c9848120,
dc60b7e5b5, c21913fdd4, 59e31f94ab, cddefa9b0d, 26d71fdd60, ced8f532dd
.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.1.0-dev0
+current_version = 0.1.1-rc0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
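The `parse` regex is what lets bumpversion understand pre-release versions such as `0.1.0-dev0` and `0.1.1-rc0`. A minimal standalone sketch (not part of the changeset) of how its named groups decompose the new version string:

```python
import re

# Same pattern as the `parse` option in .bumpversion.cfg above
PARSE = (r'(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'
         r'(\-(?P<release>[a-z]+)(?P<build>\d+))?')

match = re.fullmatch(PARSE, '0.1.1-rc0')
print(match.groupdict())
# {'major': '0', 'minor': '1', 'patch': '1', 'release': 'rc', 'build': '0'}
```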
.github/workflows/pythonapp.yml  (vendored, 5 changed lines)
@@ -1,7 +1,7 @@
 # This workflow will install Python dependencies, run tests and lint with a single version of Python
 # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
 
-name: Tests
+name: tests
 
 on:
   push:
@@ -24,6 +24,9 @@ jobs:
       run: |
        python -m pip install --upgrade pip
        pip install .
+    - name: Install extras
+      run: |
+        pip install -r requirements.txt
     - name: Lint with flake8
       run: |
        pip install flake8
.travis.yml  (new file, 18 lines)
@@ -0,0 +1,18 @@
+dist: bionic
+sudo: false
+language: python
+python: 3.8
+cache:
+  directories:
+    - ./tests/artifacts
+
+install:
+  - pip install . --progress-bar off
+  - pip install -r requirements.txt
+
+script:
+  - coverage run -m pytest
+
+# Push the results to codecov
+after_success:
+  - bash <(curl -s https://codecov.io/bash)
MANIFEST.in
@@ -1,9 +1,11 @@
 include .bumpversion.cfg
 include LICENSE
 include tox.ini
+include *.yml
 recursive-include docs *.bat
 recursive-include docs *.png
 recursive-include docs *.py
 recursive-include docs *.rst
 recursive-include docs Makefile
 recursive-include examples *.py
+recursive-include tests *.py
README.md  (25 changed lines)
@@ -3,8 +3,14 @@
 ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
 prototype-based machine learning algorithms.
 
-
+[](https://travis-ci.org/si-cim/prototorch)
+
+[](https://github.com/si-cim/prototorch/releases)
+[](https://pypi.org/project/prototorch/)
 [](https://codecov.io/gh/si-cim/prototorch)
+[](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&utm_medium=referral&utm_content=si-cim/prototorch&utm_campaign=Badge_Grade)
+
+[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
 
 ## Description
 
@@ -22,12 +28,12 @@ provided by PyTorch.
 ## Installation
 
 ProtoTorch can be installed using `pip`.
-```
+```bash
 pip install prototorch
 ```
 
 To install the bleeding-edge features and improvements:
-```
+```bash
 git clone https://github.com/si-cim/prototorch.git
 git checkout dev
 cd prototorch
@@ -47,3 +53,16 @@ API, with more algorithms and techniques coming soon. If you would simply like
 to be able to use those algorithms to train large ML models on a GPU, ProtoTorch
 lets you do this without requiring a black-belt in high-performance Tensor
 computation.
+
+## Bibtex
+
+If you would like to cite the package, please use this:
+```bibtex
+@misc{Ravichandran2020,
+  author = {Ravichandran, J},
+  title = {ProtoTorch},
+  year = {2020},
+  publisher = {GitHub},
+  journal = {GitHub repository},
+  howpublished = {\url{https://github.com/si-cim/prototorch}}
+}
RELEASE.md  (new file, 11 lines)
@@ -0,0 +1,11 @@
+# ProtoTorch Releases
+
+## Release 0.1.1-dev0
+
+### Includes
+- Minor bugfixes.
+- 100% line coverage.
+
+## Release 0.1.0-dev0
+
+Initial public release of ProtoTorch.
examples/glvq_iris.py
@@ -1,4 +1,4 @@
-"""ProtoTorch GLVQ example using 2D Iris data"""
+"""ProtoTorch GLVQ example using 2D Iris data."""
 
 import numpy as np
 import torch
@@ -8,7 +8,7 @@ from sklearn.preprocessing import StandardScaler
 
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import GLVQLoss
-from prototorch.modules.prototypes import AddPrototypes1D
+from prototorch.modules.prototypes import Prototypes1D
 
 # Prepare and preprocess the data
 scaler = StandardScaler()
@@ -21,11 +21,12 @@ x_train = scaler.transform(x_train)
 # Define the GLVQ model
 class Model(torch.nn.Module):
     def __init__(self, **kwargs):
+        """GLVQ model."""
         super().__init__()
-        self.p1 = AddPrototypes1D(input_dim=2,
+        self.p1 = Prototypes1D(input_dim=2,
                                prototypes_per_class=1,
                                nclasses=3,
                                prototype_initializer='zeros')
 
     def forward(self, x):
         protos = self.p1.prototypes
@@ -41,13 +42,17 @@ model = Model()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
 
+x_in = torch.Tensor(x_train)
+y_in = torch.Tensor(y_train)
+
 # Training loop
-fig = plt.figure('Prototype Visualization')
+title = 'Prototype Visualization'
+fig = plt.figure(title)
 for epoch in range(70):
-    # Compute loss.
-    distances, plabels = model(torch.tensor(x_train))
-    loss = criterion([distances, plabels], torch.tensor(y_train))
-    print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():02.02f}')
+    # Compute loss
+    dis, plabels = model(x_in)
+    loss = criterion([dis, plabels], y_in)
+    print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
 
     # Take a gradient descent step
     optimizer.zero_grad()
@@ -60,6 +65,9 @@ for epoch in range(70):
     # Visualize the data and the prototypes
     ax = fig.gca()
     ax.cla()
+    ax.set_title(title)
+    ax.set_xlabel('Data dimension 1')
+    ax.set_ylabel('Data dimension 2')
     cmap = 'viridis'
     ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
     ax.scatter(protos[:, 0],
@@ -71,28 +79,17 @@ for epoch in range(70):
                s=50)
 
     # Paint decision regions
-    border = 1
-    resolution = 50
     x = np.vstack((x_train, protos))
-    x_min, x_max = x[:, 0].min(), x[:, 0].max()
-    y_min, y_max = x[:, 1].min(), x[:, 1].max()
-    x_min, x_max = x_min - border, x_max + border
-    y_min, y_max = y_min - border, y_max + border
-    try:
-        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1.0 / resolution),
-                             np.arange(y_min, y_max, 1.0 / resolution))
-    except ValueError as ve:
-        print(ve)
-        raise ValueError(f'x_min: {x_min}, x_max: {x_max}. '
-                         f'x_min - x_max is {x_max - x_min}.')
-    except MemoryError as me:
-        print(me)
-        raise ValueError('Too many points. ' 'Try reducing the resolution.')
+    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
+    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
+    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
+                         np.arange(y_min, y_max, 1 / 50))
     mesh_input = np.c_[xx.ravel(), yy.ravel()]
 
-    torch_input = torch.from_numpy(mesh_input)
+    torch_input = torch.Tensor(mesh_input)
     d = model(torch_input)[0]
-    y_pred = np.argmin(d.detach().numpy(), axis=1)
+    y_pred = np.argmin(d.detach().numpy(),
+                       axis=1)  # assume one prototype per class
     y_pred = y_pred.reshape(xx.shape)
 
     # Plot voronoi regions
@@ -100,4 +97,5 @@ for epoch in range(70):
 
     ax.set_xlim(left=x_min + 0, right=x_max - 0)
     ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
+
     plt.pause(0.1)
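One subtle fix in the hunk above: `torch.from_numpy` preserves NumPy's default `float64` dtype, which clashes with the model's `float32` parameters, whereas `torch.Tensor` copies and casts to `float32`. A minimal illustration (standalone, not part of the changeset):

```python
import numpy as np
import torch

mesh_input = np.random.rand(4, 2)           # NumPy defaults to float64
print(torch.from_numpy(mesh_input).dtype)   # torch.float64 (shares memory)
print(torch.Tensor(mesh_input).dtype)       # torch.float32 (copies and casts)
```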
prototorch/__init__.py
@@ -1 +1,11 @@
-__version__ = '0.1.0-dev0'
+"""ProtoTorch package."""
+
+__version__ = '0.1.1-rc0'
+
+from prototorch import datasets, functions, modules
+
+__all__ = [
+    'datasets',
+    'functions',
+    'modules',
+]
prototorch/datasets/__init__.py  (new file, 7 lines)
@@ -0,0 +1,7 @@
+"""ProtoTorch datasets."""
+
+from .tecator import Tecator
+
+__all__ = [
+    'Tecator',
+]
prototorch/datasets/abstract.py  (new file, 87 lines)
@@ -0,0 +1,87 @@
+"""ProtoTorch abstract dataset classes.
+
+Based on `torchvision.VisionDataset` and `torchvision.MNIST`
+
+For the original code, see:
+https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
+https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
+"""
+
+import os
+
+import torch
+
+
+class Dataset(torch.utils.data.Dataset):
+    """Abstract dataset class to be inherited."""
+    _repr_indent = 2
+
+    def __init__(self, root):
+        if isinstance(root, torch._six.string_classes):
+            root = os.path.expanduser(root)
+        self.root = root
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+class ProtoDataset(Dataset):
+    """Abstract dataset class to be inherited."""
+    training_file = 'training.pt'
+    test_file = 'test.pt'
+
+    def __init__(self, root, train=True, download=True, verbose=True):
+        super().__init__(root)
+        self.train = train  # training set or test set
+        self.verbose = verbose
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError('Dataset not found. '
+                               'You can use download=True to download it')
+
+        data_file = self.training_file if self.train else self.test_file
+
+        self.data, self.targets = torch.load(
+            os.path.join(self.processed_folder, data_file))
+
+    @property
+    def raw_folder(self):
+        return os.path.join(self.root, self.__class__.__name__, 'raw')
+
+    @property
+    def processed_folder(self):
+        return os.path.join(self.root, self.__class__.__name__, 'processed')
+
+    @property
+    def class_to_idx(self):
+        return {_class: i for i, _class in enumerate(self.classes)}
+
+    def _check_exists(self):
+        return (os.path.exists(
+            os.path.join(self.processed_folder, self.training_file))
+                and os.path.exists(
+                    os.path.join(self.processed_folder, self.test_file)))
+
+    def __repr__(self):
+        head = 'Dataset ' + self.__class__.__name__
+        body = ['Number of datapoints: {}'.format(self.__len__())]
+        if self.root is not None:
+            body.append('Root location: {}'.format(self.root))
+        body += self.extra_repr().splitlines()
+        lines = [head] + [' ' * self._repr_indent + line for line in body]
+        return '\n'.join(lines)
+
+    def extra_repr(self):
+        return f"Split: {'Train' if self.train is True else 'Test'}"
+
+    def __len__(self):
+        return len(self.data)
+
+    def download(self):
+        raise NotImplementedError
prototorch/datasets/tecator.py  (new file, 102 lines)
@@ -0,0 +1,102 @@
+"""Tecator dataset for classification.
+
+URL:
+    http://lib.stat.cmu.edu/datasets/tecator
+
+LICENCE / TERMS / COPYRIGHT:
+    This is the Tecator data set: The task is to predict the fat content
+    of a meat sample on the basis of its near infrared absorbance spectrum.
+    -------------------------------------------------------------------------
+    1. Statement of permission from Tecator (the original data source)
+
+    These data are recorded on a Tecator Infratec Food and Feed Analyzer
+    working in the wavelength range 850 - 1050 nm by the Near Infrared
+    Transmission (NIT) principle. Each sample contains finely chopped pure
+    meat with different moisture, fat and protein contents.
+
+    If results from these data are used in a publication we want you to
+    mention the instrument and company name (Tecator) in the publication.
+    In addition, please send a preprint of your article to
+
+        Karin Thente, Tecator AB,
+        Box 70, S-263 21 Hoganas, Sweden
+
+    The data are available in the public domain with no responsability from
+    the original data source. The data can be redistributed as long as this
+    permission note is attached.
+
+    For more information about the instrument - call Perstorp Analytical's
+    representative in your area.
+
+Description:
+    For each meat sample the data consists of a 100 channel spectrum of
+    absorbances and the contents of moisture (water), fat and protein.
+    The absorbance is -log10 of the transmittance
+    measured by the spectrometer. The three contents, measured in percent,
+    are determined by analytic chemistry.
+"""
+
+import os
+
+import numpy as np
+import torch
+from torchvision.datasets.utils import download_file_from_google_drive
+
+from prototorch.datasets.abstract import ProtoDataset
+
+
+class Tecator(ProtoDataset):
+    """Tecator dataset for classification."""
+    resources = [
+        ('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
+         'ba5607c580d0f91bb27dc29d13c2f8df'),
+    ]  # (google_storage_id, md5hash)
+    classes = ['0 - low_fat', '1 - high_fat']
+
+    def __getitem__(self, index):
+        img, target = self.data[index], int(self.targets[index])
+        return img, target
+
+    def download(self):
+        """Download the data if it doesn't exist in already."""
+        if self._check_exists():
+            return
+
+        if self.verbose:
+            print('Making directories...')
+        os.makedirs(self.raw_folder, exist_ok=True)
+        os.makedirs(self.processed_folder, exist_ok=True)
+
+        if self.verbose:
+            print('Downloading...')
+        for fileid, md5 in self.resources:
+            filename = 'tecator.npz'
+            download_file_from_google_drive(fileid,
+                                            root=self.raw_folder,
+                                            filename=filename,
+                                            md5=md5)
+
+        if self.verbose:
+            print('Processing...')
+        with np.load(os.path.join(self.raw_folder, 'tecator.npz'),
+                     allow_pickle=False) as f:
+            x_train, y_train = f['x_train'], f['y_train']
+            x_test, y_test = f['x_test'], f['y_test']
+        training_set = [
+            torch.tensor(x_train, dtype=torch.float32),
+            torch.tensor(y_train),
+        ]
+        test_set = [
+            torch.tensor(x_test, dtype=torch.float32),
+            torch.tensor(y_test),
+        ]
+
+        with open(os.path.join(self.processed_folder, self.training_file),
+                  'wb') as f:
+            torch.save(training_set, f)
+        with open(os.path.join(self.processed_folder, self.test_file),
+                  'wb') as f:
+            torch.save(test_set, f)
+
+        if self.verbose:
+            print('Done!')
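A minimal usage sketch for the new dataset class; it assumes network access for the Google Drive download, and the `root` path is illustrative:

```python
import torch

from prototorch.datasets import Tecator

train_data = Tecator(root='./artifacts', train=True, download=True)
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=64,
                                           shuffle=True)
x, y = next(iter(train_loader))
print(x.shape)  # torch.Size([64, 100]): batches of 100-channel spectra
```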
prototorch/functions/__init__.py  (new file, 12 lines)
@@ -0,0 +1,12 @@
+"""ProtoTorch functions."""
+
+from .activations import identity, sigmoid_beta, swish_beta
+from .competitions import knnc, wtac
+
+__all__ = [
+    'identity',
+    'sigmoid_beta',
+    'swish_beta',
+    'knnc',
+    'wtac',
+]
prototorch/functions/activations.py
@@ -5,44 +5,60 @@ import torch
 ACTIVATIONS = dict()
 
 
-def register_activation(func):
-    ACTIVATIONS[func.__name__] = func
-    return func
+# def register_activation(scriptf):
+#     ACTIVATIONS[scriptf.name] = scriptf
+#     return scriptf
+def register_activation(function):
+    """Add the activation function to the registry."""
+    ACTIVATIONS[function.__name__] = function
+    return function
 
 
 @register_activation
-def identity(input, **kwargs):
-    """:math:`f(x) = x`"""
-    return input
+# @torch.jit.script
+def identity(x, beta=torch.tensor(0)):
+    """Identity activation function.
+
+    Definition:
+        :math:`f(x) = x`
+    """
+    return x
 
 
 @register_activation
-def sigmoid_beta(input, beta=10):
-    """:math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
+# @torch.jit.script
+def sigmoid_beta(x, beta=torch.tensor(10)):
+    r"""Sigmoid activation function with scaling.
+
+    Definition:
+        :math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
 
     Keyword Arguments:
-        beta (float): Parameter :math:`\\beta`
+        beta (`torch.tensor`): Scaling parameter :math:`\beta`
     """
-    out = torch.reciprocal(1.0 + torch.exp(-beta * input))
+    out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x))
     return out
 
 
 @register_activation
-def swish_beta(input, beta=10):
-    """:math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
+# @torch.jit.script
+def swish_beta(x, beta=torch.tensor(10)):
+    r"""Swish activation function with scaling.
+
+    Definition:
+        :math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
 
     Keyword Arguments:
-        beta (float): Parameter :math:`\\beta`
+        beta (`torch.tensor`): Scaling parameter :math:`\beta`
     """
-    out = input * sigmoid_beta(input, beta=beta)
+    out = x * sigmoid_beta(x, beta=beta)
     return out
 
 
 def get_activation(funcname):
+    """Deserialize the activation function."""
     if callable(funcname):
         return funcname
-    else:
-        if funcname in ACTIVATIONS:
-            return ACTIVATIONS.get(funcname)
-        else:
-            raise NameError(f'Activation {funcname} was not found.')
+    if funcname in ACTIVATIONS:
+        return ACTIVATIONS.get(funcname)
+    raise NameError(f'Activation {funcname} was not found.')
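Note that after this change `beta` is expected to be a `torch.tensor` rather than a Python float (and `sigmoid_beta` truncates it via `int(beta.item())`). A minimal standalone sketch of the registry round-trip:

```python
import torch

from prototorch.functions.activations import get_activation

squashing = get_activation('sigmoid_beta')  # look up by name
x = torch.linspace(-1.0, 1.0, steps=5)
print(squashing(x, beta=torch.tensor(10)))
# same result as calling sigmoid_beta(x, beta=torch.tensor(10)) directly
```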
prototorch/functions/competitions.py
@@ -3,13 +3,43 @@
 import torch
 
 
+# @torch.jit.script
+def stratified_min(distances, labels):
+    clabels = torch.unique(labels, dim=0)
+    nclasses = clabels.size()[0]
+    if distances.size()[1] == nclasses:
+        # skip if only one prototype per class
+        return distances
+    batch_size = distances.size()[0]
+    winning_distances = torch.zeros(nclasses, batch_size)
+    inf = torch.full_like(distances.T, fill_value=float('inf'))
+    # distances_to_wpluses = torch.where(matcher, distances, inf)
+    for i, cl in enumerate(clabels):
+        # cdists = distances.T[labels == cl]
+        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
+        if labels.ndim == 2:
+            # if the labels are one-hot vectors
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        cdists = torch.where(matcher, distances.T, inf).T
+        winning_distances[i] = torch.min(cdists, dim=1,
+                                         keepdim=True).values.squeeze()
+    if labels.ndim == 2:
+        # Transpose to return with `batch_size` first and
+        # reverse the columns to fix the ordering of the classes
+        return torch.flip(winning_distances.T, dims=(1, ))
+
+    return winning_distances.T  # return with `batch_size` first
+
+
+# @torch.jit.script
 def wtac(distances, labels):
     winning_indices = torch.min(distances, dim=1).indices
     winning_labels = labels[winning_indices].squeeze()
     return winning_labels
 
 
+# @torch.jit.script
 def knnc(distances, labels, k):
-    winning_indices = torch.topk(-distances, k=k, dim=1).indices
+    winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
     winning_labels = labels[winning_indices].squeeze()
     return winning_labels
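A toy sketch contrasting the two competition flavours (values chosen by hand): `wtac` returns the class label of the single closest prototype per sample, while the new `stratified_min` reduces a batch-by-prototypes distance matrix to one minimum distance per class. The `test_stratified_min` case added to the test suite below uses the same fixture.

```python
import torch

from prototorch.functions.competitions import stratified_min, wtac

d = torch.tensor([[1., 0., 2., 3.],
                  [9., 8., 0., 1.]])    # 2 samples x 4 prototypes
labels = torch.tensor([0, 0, 1, 2])     # class label of each prototype

print(wtac(d, labels))            # tensor([0, 1])
print(stratified_min(d, labels))  # tensor([[0., 2., 3.], [8., 0., 1.]])
```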
prototorch/functions/distances.py
@@ -28,26 +28,19 @@ def euclidean_distance(x, y):
 
 
 def lpnorm_distance(x, y, p):
-    """Compute :math:`{\\langle x, y \\rangle}_p`.
+    r"""Compute :math:`{\langle x, y \rangle}_p`.
 
     Expected dimension of x is 2.
     Expected dimension of y is 2.
     """
-    # # DEPRECATED in favor of torch.cdist
-    # expanded_x = x.unsqueeze(dim=1)
-    # batchwise_difference = y - expanded_x
-    # differences_raised = torch.pow(batchwise_difference, p)
-    # distances_raised = torch.sum(differences_raised, axis=2)
-    # distances = torch.pow(distances_raised, 1.0 / p)
-    # return distances
     distances = torch.cdist(x, y, p=p)
     return distances
 
 
 def omega_distance(x, y, omega):
-    """Omega distance.
+    r"""Omega distance.
 
-    Compute :math:`{\\langle \\Omega x, \\Omega y \\rangle}_p`
+    Compute :math:`{\langle \Omega x, \Omega y \rangle}_p`
 
     Expected dimension of x is 2.
     Expected dimension of y is 2.
@@ -60,9 +53,9 @@ def omega_distance(x, y, omega):
 
 
 def lomega_distance(x, y, omegas):
-    """Localized Omega distance.
+    r"""Localized Omega distance.
 
-    Compute :math:`{\\langle \\Omega_k x, \\Omega_k y_k \\rangle}_p`
+    Compute :math:`{\langle \Omega_k x, \Omega_k y_k \rangle}_p`
 
     Expected dimension of x is 2.
     Expected dimension of y is 2.
@@ -76,3 +69,7 @@ def lomega_distance(x, y, omegas):
     distances = torch.sum(differences_squared, dim=2)
     distances = distances.permute(1, 0)
     return distances
+
+
+# Aliases
+sed = squared_euclidean_distance
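The new `sed` alias shortens the common squared-Euclidean case. A minimal shape sketch (x is a data batch, y plays the role of prototypes):

```python
import torch

from prototorch.functions.distances import lpnorm_distance, sed

x = torch.randn(8, 2)   # 8 data points in 2-D
y = torch.randn(3, 2)   # 3 prototypes in 2-D

print(sed(x, y).shape)                   # torch.Size([8, 3])
print(lpnorm_distance(x, y, p=2).shape)  # torch.Size([8, 3]), via torch.cdist
```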
prototorch/functions/initializers.py
@@ -7,87 +7,97 @@ import torch
 INITIALIZERS = dict()
 
 
-def register_initializer(func):
-    INITIALIZERS[func.__name__] = func
-    return func
+def register_initializer(function):
+    """Add the initializer to the registry."""
+    INITIALIZERS[function.__name__] = function
+    return function
 
 
-def labels_from(distribution):
+def labels_from(distribution, one_hot=True):
     """Takes a distribution tensor and returns a labels tensor."""
     nclasses = distribution.shape[0]
     llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
     # labels = [l for cl in llist for l in cl]  # flatten the list of lists
-    labels = list(chain(*llist))  # flatten using itertools.chain
-    return torch.tensor(labels, requires_grad=False)
+    flat_llist = list(chain(*llist))  # flatten label list with itertools.chain
+    plabels = torch.tensor(flat_llist, requires_grad=False)
+    if one_hot:
+        return torch.eye(nclasses)[plabels]
+    return plabels
 
 
 @register_initializer
-def ones(x_train, y_train, prototype_distribution):
+def ones(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.ones(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
-def zeros(x_train, y_train, prototype_distribution):
+def zeros(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.zeros(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
-def rand(x_train, y_train, prototype_distribution):
+def rand(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.rand(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
-def randn(x_train, y_train, prototype_distribution):
+def randn(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.randn(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels
 
 
 @register_initializer
-def stratified_mean(x_train, y_train, prototype_distribution):
+def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
-    plabels = labels_from(prototype_distribution)
-    for i, l in enumerate(plabels):
-        xl = x_train[y_train == l]
+    plabels = labels_from(prototype_distribution, one_hot)
+    for i, label in enumerate(plabels):
+        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
+        if one_hot:
+            nclasses = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        xl = x_train[matcher]
         mean_xl = torch.mean(xl, dim=0)
         protos[i] = mean_xl
+    plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels
 
 
 @register_initializer
-def stratified_random(x_train, y_train, prototype_distribution):
-    gen = torch.manual_seed(torch.initial_seed())
+def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
-    plabels = labels_from(prototype_distribution)
-    for i, l in enumerate(plabels):
-        xl = x_train[y_train == l]
-        rand_index = torch.zeros(1).long().random_(0,
-                                                   xl.shape[1] - 1,
-                                                   generator=gen)
+    plabels = labels_from(prototype_distribution, one_hot)
+    for i, label in enumerate(plabels):
+        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
+        if one_hot:
+            nclasses = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        xl = x_train[matcher]
+        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
         random_xl = xl[rand_index]
         protos[i] = random_xl
+    plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels
 
 
 def get_initializer(funcname):
+    """Deserialize the initializer."""
     if callable(funcname):
         return funcname
-    else:
-        if funcname in INITIALIZERS:
-            return INITIALIZERS.get(funcname)
-        else:
-            raise NameError(f'Initializer {funcname} was not found.')
+    if funcname in INITIALIZERS:
+        return INITIALIZERS.get(funcname)
+    raise NameError(f'Initializer {funcname} was not found.')
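Since `one_hot` now defaults to `True`, all initializers return one-hot prototype labels unless told otherwise, matching the one-hot support added elsewhere in this changeset. A minimal sketch with toy data:

```python
import torch

from prototorch.functions.initializers import stratified_mean

x_train = torch.Tensor([[0., 0.], [1., 1.], [2., 2.], [3., 3.]])
y_train = torch.eye(2)[[0, 0, 1, 1]]  # one-hot targets for 2 classes
pdist = torch.tensor([1, 1])          # one prototype per class

protos, plabels = stratified_mean(x_train, y_train, pdist, one_hot=True)
print(protos)   # per-class means: tensor([[0.5, 0.5], [2.5, 2.5]])
print(plabels)  # one-hot labels:  tensor([[1., 0.], [0., 1.]])
```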
prototorch/functions/losses.py
@@ -3,23 +3,24 @@
 import torch
 
 
-def glvq_loss(distances, target_labels, prototype_labels):
-    """GLVQ loss function with support for one-hot labels."""
-    matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
-    if prototype_labels.ndim == 2:
+def _get_dp_dm(distances, targets, plabels):
+    matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
+    if plabels.ndim == 2:
         # if the labels are one-hot vectors
-        nclasses = target_labels.size()[1]
+        nclasses = targets.size()[1]
         matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
     not_matcher = torch.bitwise_not(matcher)
 
-    dplus_criterion = distances * matcher > 0.0
-    dminus_criterion = distances * not_matcher > 0.0
-
     inf = torch.full_like(distances, fill_value=float('inf'))
-    distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
-    distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
-    dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
-    dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
+    d_matching = torch.where(matcher, distances, inf)
+    d_unmatching = torch.where(not_matcher, distances, inf)
+    dp = torch.min(d_matching, dim=1, keepdim=True).values
+    dm = torch.min(d_unmatching, dim=1, keepdim=True).values
+    return dp, dm
 
-    mu = (dpluses - dminuses) / (dpluses + dminuses)
+
+def glvq_loss(distances, target_labels, prototype_labels):
+    """GLVQ loss function with support for one-hot labels."""
+    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
+    mu = (dp - dm) / (dp + dm)
     return mu
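The refactor only extracts the winner lookup into `_get_dp_dm`; the loss value is unchanged. Per sample, `dp` is the distance to the closest prototype of the correct class, `dm` the distance to the closest prototype of any other class, and `mu = (dp - dm) / (dp + dm)` is negative exactly when the sample is classified correctly. A toy check:

```python
import torch

from prototorch.functions.losses import glvq_loss

distances = torch.tensor([[0.3, 1.2, 0.7]])  # 1 sample, 3 prototypes
targets = torch.tensor([0])                  # true class of the sample
plabels = torch.tensor([0, 1, 2])            # class label of each prototype

mu = glvq_loss(distances, targets, plabels)
print(mu)  # tensor([[-0.4000]]): dp=0.3, dm=0.7, (0.3-0.7)/(0.3+0.7)
```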
prototorch/modules/__init__.py  (new file, 7 lines)
@@ -0,0 +1,7 @@
+"""ProtoTorch modules."""
+
+from .prototypes import Prototypes1D
+
+__all__ = [
+    'Prototypes1D',
+]
prototorch/modules/losses.py
@@ -7,12 +7,11 @@ from prototorch.functions.losses import glvq_loss
 
 
 class GLVQLoss(torch.nn.Module):
-    """GLVQ Loss."""
     def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
         super().__init__(**kwargs)
         self.margin = margin
         self.squashing = get_activation(squashing)
-        self.beta = beta
+        self.beta = torch.tensor(beta)
 
     def forward(self, outputs, targets):
         distances, plabels = outputs
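Storing `beta` as a `torch.tensor` means the squashing functions from `prototorch.functions.activations` receive the argument type they now expect. A minimal sketch of the module wrapper, mirroring the call pattern in the Iris example above (the reduction applied inside `forward` is not shown in this hunk):

```python
import torch

from prototorch.modules.losses import GLVQLoss

criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)

distances = torch.tensor([[0.3, 1.2, 0.7]])  # 1 sample, 3 prototypes
plabels = torch.tensor([0, 1, 2])
targets = torch.Tensor([0])

loss = criterion([distances, plabels], targets)  # squashed GLVQ score
```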
prototorch/modules/prototypes.py
@@ -1,57 +1,165 @@
 """ProtoTorch prototype modules."""
 
+import warnings
+
+import numpy as np
 import torch
 
+from prototorch.functions.competitions import wtac
+from prototorch.functions.distances import sed
 from prototorch.functions.initializers import get_initializer
 
 
-class AddPrototypes1D(torch.nn.Module):
+class _Prototypes(torch.nn.Module):
+    """Abstract prototypes class."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _validate_prototype_distribution(self):
+        if 0 in self.prototype_distribution:
+            warnings.warn('Are you sure about the `0` in '
+                          '`prototype_distribution`?')
+
+    def extra_repr(self):
+        return f'prototypes.shape: {tuple(self.prototypes.shape)}'
+
+    def forward(self):
+        return self.prototypes, self.prototype_labels
+
+
+class Prototypes1D(_Prototypes):
+    r"""Create a learnable set of one-dimensional prototypes.
+
+    TODO Complete this doc-string
+
+    Kwargs:
+        prototypes_per_class: number of prototypes to use per class.
+            Default: ``1``
+        prototype_initializer: prototype initializer.
+            Default: ``'ones'``
+        prototype_distribution: prototype distribution vector.
+            Default: ``None``
+        input_dim: dimension of the incoming data.
+        nclasses: number of classes.
+        data: If set to ``None``, data-dependent initializers will be ignored.
+            Default: ``None``
+
+    Shape:
+        - Input: :math:`(N, H_{in})`
+          where :math:`H_{in} = \text{input_dim}`.
+        - Output: :math:`(N, H_{out})`
+          where :math:`H_{out} = \text{total_prototypes}`.
+
+    Attributes:
+        prototypes: the learnable weights of the module of shape
+            :math:`(\text{total_prototypes}, \text{prototype_dimension})`.
+        prototype_labels: the non-learnable labels of the prototypes.
+
+    Examples::
+
+        >>> p = Prototypes1D(input_dim=20, nclasses=10)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([20, 10])
+    """
     def __init__(self,
                  prototypes_per_class=1,
-                 prototype_distribution=None,
                  prototype_initializer='ones',
+                 prototype_distribution=None,
                  data=None,
+                 dtype=torch.float32,
+                 one_hot_labels=False,
                  **kwargs):
+
+        # Convert tensors to python lists before processing
+        if prototype_distribution is not None:
+            if not isinstance(prototype_distribution, list):
+                prototype_distribution = prototype_distribution.tolist()
+
         if data is None:
             if 'input_dim' not in kwargs:
                 raise NameError('`input_dim` required if '
                                 'no `data` is provided.')
-            if prototype_distribution is not None:
-                nclasses = sum(prototype_distribution)
+            if prototype_distribution:
+                kwargs_nclasses = sum(prototype_distribution)
             else:
                 if 'nclasses' not in kwargs:
                     raise NameError('`prototype_distribution` required if '
                                     'both `data` and `nclasses` are not '
                                     'provided.')
-                nclasses = kwargs.pop('nclasses')
+                kwargs_nclasses = kwargs.pop('nclasses')
             input_dim = kwargs.pop('input_dim')
-            # input_shape = (input_dim, )
-            x_train = torch.rand(nclasses, input_dim)
-            y_train = torch.arange(nclasses)
+            if prototype_initializer in [
+                    'stratified_mean', 'stratified_random'
+            ]:
+                warnings.warn(
+                    f'`prototype_initializer`: `{prototype_initializer}` '
+                    'requires `data`, but `data` is not provided. '
+                    'Using randomly generated data instead.')
+            x_train = torch.rand(kwargs_nclasses, input_dim)
+            y_train = torch.arange(kwargs_nclasses)
+            if one_hot_labels:
+                y_train = torch.eye(kwargs_nclasses)[y_train]
+            data = [x_train, y_train]
 
-        else:
-            x_train, y_train = data
-            x_train = torch.as_tensor(x_train)
-            y_train = torch.as_tensor(y_train)
+        x_train, y_train = data
+        x_train = torch.as_tensor(x_train).type(dtype)
+        y_train = torch.as_tensor(y_train).type(torch.int)
+        nclasses = torch.unique(y_train, dim=-1).shape[-1]
+
+        if nclasses == 1:
+            warnings.warn('Are you sure about having one class only?')
+
+        if x_train.ndim != 2:
+            raise ValueError('`data[0].ndim != 2`.')
+
+        if y_train.ndim == 2:
+            if y_train.shape[1] == 1 and one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `True` '
+                                 'but target labels are not one-hot-encoded.')
+            if y_train.shape[1] != 1 and not one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `False` '
+                                 'but target labels in `data` '
+                                 'are one-hot-encoded.')
+        if y_train.ndim == 1 and one_hot_labels:
+            raise ValueError('`one_hot_labels` is set to `True` '
+                             'but target labels are not one-hot-encoded.')
+
+        # Verify input dimension if `input_dim` is provided
+        if 'input_dim' in kwargs:
+            input_dim = kwargs.pop('input_dim')
+            if input_dim != x_train.shape[1]:
+                raise ValueError(f'Provided `input_dim`={input_dim} does '
+                                 'not match data dimension '
+                                 f'`data[0].shape[1]`={x_train.shape[1]}')
+
+        # Verify the number of classes if `nclasses` is provided
+        if 'nclasses' in kwargs:
+            kwargs_nclasses = kwargs.pop('nclasses')
+            if kwargs_nclasses != nclasses:
+                raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
+                                 'not match data labels '
+                                 '`torch.unique(data[1]).shape[0]`'
+                                 f'={nclasses}')
 
         super().__init__(**kwargs)
-        self.prototypes_per_class = prototypes_per_class
+
+        if not prototype_distribution:
+            prototype_distribution = [prototypes_per_class] * nclasses
         with torch.no_grad():
-            if not prototype_distribution:
-                num_classes = torch.unique(y_train).shape[0]
-                self.prototype_distribution = torch.tensor(
-                    [self.prototypes_per_class] * num_classes)
-            else:
-                self.prototype_distribution = torch.tensor(
-                    prototype_distribution)
+            self.prototype_distribution = torch.tensor(prototype_distribution)
+
+        self._validate_prototype_distribution()
+
         self.prototype_initializer = get_initializer(prototype_initializer)
         prototypes, prototype_labels = self.prototype_initializer(
             x_train,
             y_train,
-            prototype_distribution=self.prototype_distribution)
+            prototype_distribution=self.prototype_distribution,
+            one_hot=one_hot_labels,
+        )
+
+        # Register module parameters
         self.prototypes = torch.nn.Parameter(prototypes)
         self.prototype_labels = prototype_labels
-
-    def forward(self):
-        return self.prototypes, self.prototype_labels
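A minimal usage sketch of the renamed module with toy data; when `data` is supplied, `input_dim` and `nclasses` are inferred from it and the keyword arguments are only validated against it:

```python
import torch

from prototorch.modules.prototypes import Prototypes1D

x_train = torch.Tensor([[0., 0.], [1., 1.], [2., 2.]])
y_train = torch.LongTensor([0, 1, 2])

p1 = Prototypes1D(input_dim=2,
                  nclasses=3,
                  prototype_initializer='stratified_mean',
                  data=[x_train, y_train],
                  one_hot_labels=False)
protos, plabels = p1()  # forward() returns both tensors
print(protos.shape)     # torch.Size([3, 2]): one prototype per class
print(plabels)          # tensor([0, 1, 2])
```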
requirements.txt  (new file, 5 lines)
@@ -0,0 +1,5 @@
+matplotlib==3.1.2
+pytest==5.3.4
+requests==2.22.0
+codecov==2.0.22
+tqdm==4.44.1
setup.py  (3 changed lines)
@@ -10,7 +10,7 @@ with open('README.md', 'r') as fh:
     long_description = fh.read()
 
 setup(name='prototorch',
-      version='0.1.0-dev0',
+      version='0.1.1-rc0',
       description='Highly extensible, GPU-supported '
      'Learning Vector Quantization (LVQ) toolbox '
      'built using PyTorch and its nn API.',
@@ -27,6 +27,7 @@ setup(name='prototorch',
          'numpy>=1.9.1',
      ],
      extras_require={
+          'datasets': ['requests'],
          'examples': [
              'sklearn',
              'matplotlib',
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
95
tests/test_datasets.py
Normal file
95
tests/test_datasets.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
"""ProtoTorch datasets test suite."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from prototorch.datasets import abstract, tecator
|
||||||
|
|
||||||
|
|
||||||
|
class TestAbstract(unittest.TestCase):
|
||||||
|
def test_getitem(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
abstract.Dataset('./artifacts')[0]
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
len(abstract.Dataset('./artifacts'))
|
||||||
|
|
||||||
|
|
||||||
|
class TestProtoDataset(unittest.TestCase):
|
||||||
|
def test_getitem(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
abstract.ProtoDataset('./artifacts')[0]
|
||||||
|
|
||||||
|
def test_download(self):
|
||||||
|
with self.assertRaises(NotImplementedError):
|
||||||
|
abstract.ProtoDataset('./artifacts').download()
|
||||||
|
|
||||||
|
|
||||||
|
class TestTecator(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.artifacts_dir = './artifacts/Tecator'
|
||||||
|
self._remove_artifacts()
|
||||||
|
|
||||||
|
def _remove_artifacts(self):
|
||||||
|
if os.path.exists(self.artifacts_dir):
|
||||||
|
shutil.rmtree(self.artifacts_dir)
|
||||||
|
|
||||||
|
def test_download_false(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
self._remove_artifacts()
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
_ = tecator.Tecator(rootdir, download=False)
|
||||||
|
|
||||||
|
def test_download_caching(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
_ = tecator.Tecator(rootdir, download=True, verbose=False)
|
||||||
|
_ = tecator.Tecator(rootdir, download=False, verbose=False)
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
train = tecator.Tecator(rootdir, download=True, verbose=True)
|
||||||
|
self.assertTrue('Split: Train' in train.__repr__())
|
||||||
|
|
||||||
|
def test_download_train(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
train = tecator.Tecator(root=rootdir,
|
||||||
|
train=True,
|
||||||
|
download=True,
|
||||||
|
verbose=False)
|
||||||
|
train = tecator.Tecator(root=rootdir, download=True, verbose=False)
|
||||||
|
x_train, y_train = train.data, train.targets
|
||||||
|
self.assertEqual(x_train.shape[0], 144)
|
||||||
|
self.assertEqual(y_train.shape[0], 144)
|
||||||
|
self.assertEqual(x_train.shape[1], 100)
|
||||||
|
|
||||||
|
def test_download_test(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
|
x_test, y_test = test.data, test.targets
|
||||||
|
self.assertEqual(x_test.shape[0], 71)
|
||||||
|
self.assertEqual(y_test.shape[0], 71)
|
||||||
|
self.assertEqual(x_test.shape[1], 100)
|
||||||
|
|
||||||
|
def test_class_to_idx(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
|
_ = test.class_to_idx
|
||||||
|
|
||||||
|
def test_getitem(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
|
x, y = test[0]
|
||||||
|
self.assertEqual(x.shape[0], 100)
|
||||||
|
self.assertIsInstance(y, int)
|
||||||
|
|
||||||
|
def test_loadable_with_dataloader(self):
|
||||||
|
rootdir = self.artifacts_dir.rpartition('/')[0]
|
||||||
|
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
|
||||||
|
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
@@ -6,7 +6,148 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions import (activations, competitions, distances,
|
from prototorch.functions import (activations, competitions, distances,
|
||||||
initializers)
|
initializers, losses)
|
||||||
|
|
||||||
|
|
||||||
|
class TestActivations(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
||||||
|
self.x = torch.randn(1024, 1)
|
||||||
|
|
||||||
|
def test_registry(self):
|
||||||
|
self.assertIsNotNone(activations.ACTIVATIONS)
|
||||||
|
|
||||||
|
def test_funcname_deserialization(self):
|
||||||
|
for funcname in self.flist:
|
||||||
|
f = activations.get_activation(funcname)
|
||||||
|
iscallable = callable(f)
|
||||||
|
self.assertTrue(iscallable)
|
||||||
|
|
||||||
|
# def test_torch_script(self):
|
||||||
|
# for funcname in self.flist:
|
||||||
|
# f = activations.get_activation(funcname)
|
||||||
|
# self.assertIsInstance(f, torch.jit.ScriptFunction)
|
||||||
|
|
||||||
|
def test_callable_deserialization(self):
|
||||||
|
def dummy(x, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
for f in [dummy, lambda x: x]:
|
||||||
|
f = activations.get_activation(f)
|
||||||
|
iscallable = callable(f)
|
||||||
|
self.assertTrue(iscallable)
|
||||||
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
|
def test_unknown_deserialization(self):
|
||||||
|
for funcname in ['blubb', 'foobar']:
|
||||||
|
with self.assertRaises(NameError):
|
||||||
|
_ = activations.get_activation(funcname)
|
||||||
|
|
||||||
|
def test_identity(self):
|
||||||
|
actual = activations.identity(self.x)
|
||||||
|
desired = self.x
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_sigmoid_beta1(self):
|
||||||
|
actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
|
||||||
|
desired = torch.sigmoid(self.x)
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_swish_beta1(self):
|
||||||
|
actual = activations.swish_beta(self.x, beta=torch.tensor(1))
|
||||||
|
desired = self.x * torch.sigmoid(self.x)
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del self.x
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompetitions(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_wtac(self):
|
||||||
|
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_wtac_unequal_dist(self):
|
||||||
|
d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
|
||||||
|
labels = torch.tensor([0, 1, 1])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([0, 1])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_wtac_one_hot(self):
|
||||||
|
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
||||||
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_min(self):
|
||||||
|
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
||||||
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
|
actual = competitions.stratified_min(d, labels)
|
||||||
|
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_min_one_hot(self):
|
||||||
|
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
|
||||||
|
labels = torch.tensor([0, 0, 1, 2])
|
||||||
|
labels = torch.eye(3)[labels]
|
||||||
|
actual = competitions.stratified_min(d, labels)
|
||||||
|
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_stratified_min_simple(self):
|
||||||
|
d = torch.tensor([[0., 2., 3.], [8., 0, 1]])
|
||||||
|
labels = torch.tensor([0, 1, 2])
|
||||||
|
actual = competitions.stratified_min(d, labels)
|
||||||
|
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_knnc_k1(self):
|
||||||
|
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
 class TestDistances(unittest.TestCase):
@@ -167,103 +308,12 @@ class TestDistances(unittest.TestCase):
         del self.x, self.y
-
-
-class TestActivations(unittest.TestCase):
-    def setUp(self):
-        self.x = torch.randn(1024, 1)
-
-    def test_registry(self):
-        self.assertIsNotNone(activations.ACTIVATIONS)
-
-    def test_funcname_deserialization(self):
-        flist = ['identity', 'sigmoid_beta', 'swish_beta']
-        for funcname in flist:
-            f = activations.get_activation(funcname)
-            iscallable = callable(f)
-            self.assertTrue(iscallable)
-
-    def test_callable_deserialization(self):
-        def dummy(x, **kwargs):
-            return x
-
-        for f in [dummy, lambda x: x]:
-            f = activations.get_activation(f)
-            iscallable = callable(f)
-            self.assertTrue(iscallable)
-            self.assertEqual(1, f(1))
-
-    def test_unknown_deserialization(self):
-        for funcname in ['blubb', 'foobar']:
-            with self.assertRaises(NameError):
-                _ = activations.get_activation(funcname)
-
-    def test_identity(self):
-        actual = activations.identity(self.x)
-        desired = self.x
-        mismatch = np.testing.assert_array_almost_equal(actual,
-                                                        desired,
-                                                        decimal=5)
-        self.assertIsNone(mismatch)
-
-    def test_sigmoid_beta1(self):
-        actual = activations.sigmoid_beta(self.x, beta=1)
-        desired = torch.sigmoid(self.x)
-        mismatch = np.testing.assert_array_almost_equal(actual,
-                                                        desired,
-                                                        decimal=5)
-        self.assertIsNone(mismatch)
-
-    def test_swish_beta1(self):
-        actual = activations.swish_beta(self.x, beta=1)
-        desired = self.x * torch.sigmoid(self.x)
-        mismatch = np.testing.assert_array_almost_equal(actual,
-                                                        desired,
-                                                        decimal=5)
-        self.assertIsNone(mismatch)
-
-    def tearDown(self):
-        del self.x
-
-
-class TestCompetitions(unittest.TestCase):
-    def setUp(self):
-        pass
-
-    def test_wtac(self):
-        d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
-        labels = torch.tensor([0, 1, 2, 3])
-        actual = competitions.wtac(d, labels)
-        desired = torch.tensor([2, 0])
-        mismatch = np.testing.assert_array_almost_equal(actual,
-                                                        desired,
-                                                        decimal=5)
-        self.assertIsNone(mismatch)
-
-    def test_wtac_one_hot(self):
-        d = torch.tensor([[1.99, 3.01], [3., 2.01]])
-        labels = torch.tensor([[0, 1], [1, 0]])
-        actual = competitions.wtac(d, labels)
-        desired = torch.tensor([[0, 1], [1, 0]])
-        mismatch = np.testing.assert_array_almost_equal(actual,
-                                                        desired,
-                                                        decimal=5)
-        self.assertIsNone(mismatch)
-
-    def test_knnc_k1(self):
-        d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
-        labels = torch.tensor([0, 1, 2, 3])
-        actual = competitions.knnc(d, labels, k=1)
-        desired = torch.tensor([2, 0])
-        mismatch = np.testing.assert_array_almost_equal(actual,
-                                                        desired,
-                                                        decimal=5)
-        self.assertIsNone(mismatch)
-
-    def tearDown(self):
-        pass


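Although the TestActivations and TestCompetitions classes are removed here (they moved earlier in the file), the deleted assertions still document the activation math: sigmoid_beta must reduce to torch.sigmoid for beta=1, and swish_beta to x * sigmoid(x). A plausible sketch of both, assuming beta simply scales the sigmoid argument (the library's own definitions may differ):

import torch


def sigmoid_beta(x, beta=1):
    # Reduces to torch.sigmoid(x) for beta=1, as test_sigmoid_beta1 asserts.
    return torch.sigmoid(beta * x)


def swish_beta(x, beta=1):
    # Reduces to x * torch.sigmoid(x) for beta=1, as test_swish_beta1 asserts.
    return x * torch.sigmoid(beta * x)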
 class TestInitializers(unittest.TestCase):
     def setUp(self):
+        self.flist = [
+            'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
+            'stratified_random'
+        ]
         self.x = torch.tensor(
             [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
             dtype=torch.float32)
@@ -274,11 +324,7 @@ class TestInitializers(unittest.TestCase):
         self.assertIsNotNone(initializers.INITIALIZERS)

     def test_funcname_deserialization(self):
-        flist = [
-            'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
-            'stratified_random'
-        ]
-        for funcname in flist:
+        for funcname in self.flist:
             f = initializers.get_initializer(funcname)
             iscallable = callable(f)
             self.assertTrue(iscallable)
@@ -336,7 +382,7 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_mean_equal1(self):
         pdist = torch.tensor([1, 1])
-        actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
+        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
         desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
                                                         desired,
@@ -345,8 +391,9 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_random_equal1(self):
         pdist = torch.tensor([1, 1])
-        actual, _ = initializers.stratified_random(self.x, self.y, pdist)
-        desired = torch.tensor([[0., -1., -2.], [2., 2., 2.]])
+        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
+                                                   False)
+        desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
                                                         desired,
                                                         decimal=5)
@@ -354,7 +401,7 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_mean_equal2(self):
         pdist = torch.tensor([2, 2])
-        actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
+        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
         desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
                                 [1., 1., 1.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
@@ -362,9 +409,20 @@ class TestInitializers(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_stratified_random_equal2(self):
+        pdist = torch.tensor([2, 2])
+        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
+                                                   False)
+        desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
+                                [0., 0., 0.]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
     def test_stratified_mean_unequal(self):
         pdist = torch.tensor([1, 3])
-        actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
+        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
         desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
                                 [1., 1., 1.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
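The expected values in these initializer tests fix the stratified semantics: stratified_mean places pdist[c] copies of the class-c sample mean (for the setUp data, the class means are [5, 5, 5] and [1, 1, 1]), while stratified_random draws class samples reproducibly under the manual seed from setUp. The trailing False added to each call is apparently a one-hot-labels flag, judging by the one-hot variants introduced below. A sketch of stratified_mean for plain integer labels; stratified_mean_sketch is an illustrative name, not the library function:

import torch


def stratified_mean_sketch(x, y, pdist):
    # pdist[c] prototypes per class, all initialized at the mean of
    # the class-c samples; with the setUp data and pdist=[1, 1] this
    # yields [[5., 5., 5.], [1., 1., 1.]], as asserted above.
    protos, labels = [], []
    for c, n in enumerate(pdist):
        mean = x[y == c].mean(dim=0)
        protos += [mean] * int(n)
        labels += [c] * int(n)
    return torch.stack(protos), torch.tensor(labels)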
@@ -374,14 +432,86 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_random_unequal(self):
         pdist = torch.tensor([1, 3])
-        actual, _ = initializers.stratified_random(self.x, self.y, pdist)
-        desired = torch.tensor([[0., -1., -2.], [2., 2., 2.], [0., 0., 0.],
+        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
+                                                   False)
+        desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
                                 [0., 0., 0.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
                                                         desired,
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_stratified_mean_unequal_one_hot(self):
+        pdist = torch.tensor([1, 3])
+        y = torch.eye(2)[self.y]
+        desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
+                                 [1., 1., 1.]])
+        actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
+        desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
+        mismatch = np.testing.assert_array_almost_equal(actual1,
+                                                        desired1,
+                                                        decimal=5)
+        mismatch = np.testing.assert_array_almost_equal(actual2,
+                                                        desired2,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_random_unequal_one_hot(self):
+        pdist = torch.tensor([1, 3])
+        y = torch.eye(2)[self.y]
+        actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
+        desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
+                                 [0., 0., 0.]])
+        desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
+        mismatch = np.testing.assert_array_almost_equal(actual1,
+                                                        desired1,
+                                                        decimal=5)
+        mismatch = np.testing.assert_array_almost_equal(actual2,
+                                                        desired2,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
     def tearDown(self):
         del self.x, self.y, self.gen
         _ = torch.seed()
+
+
+class TestLosses(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def test_glvq_loss_int_labels(self):
+        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
+        labels = torch.tensor([0, 1])
+        targets = torch.ones(100)
+        batch_loss = losses.glvq_loss(distances=d,
+                                      target_labels=targets,
+                                      prototype_labels=labels)
+        loss_value = torch.sum(batch_loss, dim=0)
+        self.assertEqual(loss_value, -100)
+
+    def test_glvq_loss_one_hot_labels(self):
+        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
+        labels = torch.tensor([[0, 1], [1, 0]])
+        wl = torch.tensor([1, 0])
+        targets = torch.stack([wl for _ in range(100)], dim=0)
+        batch_loss = losses.glvq_loss(distances=d,
+                                      target_labels=targets,
+                                      prototype_labels=labels)
+        loss_value = torch.sum(batch_loss, dim=0)
+        self.assertEqual(loss_value, -100)
+
+    def test_glvq_loss_one_hot_unequal(self):
+        dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
+        d = torch.stack(dlist, dim=1)
+        labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
+        wl = torch.tensor([1, 0])
+        targets = torch.stack([wl for _ in range(100)], dim=0)
+        batch_loss = losses.glvq_loss(distances=d,
+                                      target_labels=targets,
+                                      prototype_labels=labels)
+        loss_value = torch.sum(batch_loss, dim=0)
+        self.assertEqual(loss_value, -100)
+
+    def tearDown(self):
+        pass
@@ -5,7 +5,7 @@ import unittest
 import numpy as np
 import torch

-from prototorch.modules import prototypes, losses
+from prototorch.modules import losses, prototypes


 class TestPrototypes(unittest.TestCase):
@@ -16,19 +16,23 @@ class TestPrototypes(unittest.TestCase):
         self.y = torch.tensor([0, 0, 1, 1])
         self.gen = torch.manual_seed(42)

-    def test_addprototypes1d_init_without_input_dim(self):
+    def test_prototypes1d_init_without_input_dim(self):
         with self.assertRaises(NameError):
-            _ = prototypes.AddPrototypes1D(nclasses=1)
+            _ = prototypes.Prototypes1D(nclasses=2)

-    def test_addprototypes1d_init_without_nclasses(self):
+    def test_prototypes1d_init_without_nclasses(self):
         with self.assertRaises(NameError):
-            _ = prototypes.AddPrototypes1D(input_dim=1)
+            _ = prototypes.Prototypes1D(input_dim=1)

-    def test_addprototypes1d_init_without_pdist(self):
-        p1 = prototypes.AddPrototypes1D(input_dim=6,
-                                        nclasses=2,
-                                        prototypes_per_class=4,
-                                        prototype_initializer='ones')
+    def test_prototypes1d_init_with_nclasses_1(self):
+        with self.assertWarns(UserWarning):
+            _ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
+
+    def test_prototypes1d_init_without_pdist(self):
+        p1 = prototypes.Prototypes1D(input_dim=6,
+                                     nclasses=2,
+                                     prototypes_per_class=4,
+                                     prototype_initializer='ones')
         protos = p1.prototypes
         actual = protos.detach().numpy()
         desired = torch.ones(8, 6)
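The two NameError tests above imply that Prototypes1D checks for required constructor arguments up front. A minimal sketch of such a guard; check_required_sketch and its message are illustrative, and the real validation is more involved (e.g. input_dim can apparently be inferred when data is given, per the tests further down):

def check_required_sketch(kwargs, required=('input_dim', 'nclasses')):
    # Raise NameError if a required constructor argument is missing,
    # matching the two assertRaises(NameError) cases above.
    for name in required:
        if name not in kwargs:
            raise NameError(f'`{name}` required, but not provided.')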
@@ -37,11 +41,11 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_init_without_data(self):
+    def test_prototypes1d_init_without_data(self):
         pdist = [2, 2]
-        p1 = prototypes.AddPrototypes1D(input_dim=3,
+        p1 = prototypes.Prototypes1D(input_dim=3,
                                      prototype_distribution=pdist,
                                      prototype_initializer='zeros')
         protos = p1.prototypes
         actual = protos.detach().numpy()
         desired = torch.zeros(4, 3)
@@ -50,23 +54,20 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    # def test_addprototypes1d_init_torch_pdist(self):
-    #     pdist = torch.tensor([2, 2])
-    #     p1 = prototypes.AddPrototypes1D(input_dim=3,
-    #                                     prototype_distribution=pdist,
-    #                                     prototype_initializer='zeros')
-    #     protos = p1.prototypes
-    #     actual = protos.detach().numpy()
-    #     desired = torch.zeros(4, 3)
-    #     mismatch = np.testing.assert_array_almost_equal(actual,
-    #                                                     desired,
-    #                                                     decimal=5)
-    #     self.assertIsNone(mismatch)
+    def test_prototypes1d_proto_init_without_data(self):
+        with self.assertWarns(UserWarning):
+            _ = prototypes.Prototypes1D(
+                input_dim=3,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=None)

-    def test_addprototypes1d_init_with_ppc(self):
-        p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
-                                        prototypes_per_class=2,
-                                        prototype_initializer='zeros')
+    def test_prototypes1d_init_torch_pdist(self):
+        pdist = torch.tensor([2, 2])
+        p1 = prototypes.Prototypes1D(input_dim=3,
+                                     prototype_distribution=pdist,
+                                     prototype_initializer='zeros')
         protos = p1.prototypes
         actual = protos.detach().numpy()
         desired = torch.zeros(4, 3)
@@ -75,10 +76,119 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_init_with_pdist(self):
-        p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
-                                        prototype_distribution=[6, 9],
-                                        prototype_initializer='zeros')
+    def test_prototypes1d_init_without_inputdim_with_data(self):
+        _ = prototypes.Prototypes1D(nclasses=2,
+                                    prototypes_per_class=1,
+                                    prototype_initializer='stratified_mean',
+                                    data=[[[1.], [0.]], [1, 0]])
+
+    def test_prototypes1d_init_with_int_data(self):
+        _ = prototypes.Prototypes1D(nclasses=2,
+                                    prototypes_per_class=1,
+                                    prototype_initializer='stratified_mean',
+                                    data=[[[1], [0]], [1, 0]])
+
+    def test_prototypes1d_init_one_hot_without_data(self):
+        _ = prototypes.Prototypes1D(input_dim=1,
+                                    nclasses=2,
+                                    prototypes_per_class=1,
+                                    prototype_initializer='stratified_mean',
+                                    data=None,
+                                    one_hot_labels=True)
+
+    def test_prototypes1d_init_one_hot_labels_false(self):
+        """Test if ValueError is raised when `one_hot_labels` is set to `False`
+        but the provided `data` has one-hot encoded labels.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=([[0.], [1.]], [[0, 1], [1, 0]]),
+                one_hot_labels=False)
+
+    def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
+        """Test if ValueError is raised when `one_hot_labels` is set to `True`
+        but the provided `data` does not contain one-hot encoded labels.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=([[0.], [1.]], [0, 1]),
+                one_hot_labels=True)
+
+    def test_prototypes1d_init_one_hot_labels_true(self):
+        """Test if ValueError is raised when `one_hot_labels` is set to `True`
+        but the provided `data` contains 2D targets but
+        does not contain one-hot encoded labels.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=([[0.], [1.]], [[0], [1]]),
+                one_hot_labels=True)
+
+    def test_prototypes1d_init_with_int_dtype(self):
+        with self.assertRaises(RuntimeError):
+            _ = prototypes.Prototypes1D(
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=[[[1], [0]], [1, 0]],
+                dtype=torch.int32)
+
+    def test_prototypes1d_inputndim_with_data(self):
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(input_dim=1,
+                                        nclasses=1,
+                                        prototypes_per_class=1,
+                                        data=[[1.], [1]])
+
+    def test_prototypes1d_inputdim_with_data(self):
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=2,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=[[[1.], [0.]], [1, 0]])
+
+    def test_prototypes1d_nclasses_with_data(self):
+        """Test ValueError raise if provided `nclasses` is not the same
+        as the one computed from the provided `data`.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=1,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=[[[1.], [2.]], [1, 2]])
+
+    def test_prototypes1d_init_with_ppc(self):
+        p1 = prototypes.Prototypes1D(data=[self.x, self.y],
+                                     prototypes_per_class=2,
+                                     prototype_initializer='zeros')
+        protos = p1.prototypes
+        actual = protos.detach().numpy()
+        desired = torch.zeros(4, 3)
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_prototypes1d_init_with_pdist(self):
+        p1 = prototypes.Prototypes1D(data=[self.x, self.y],
+                                     prototype_distribution=[6, 9],
+                                     prototype_initializer='zeros')
         protos = p1.prototypes
         actual = protos.detach().numpy()
         desired = torch.zeros(15, 3)
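Taken together, the one_hot_labels tests above define a three-way contract: 2D targets with one_hot_labels=False must fail, 1D targets with one_hot_labels=True must fail, and 2D targets that are not valid one-hot rows must fail as well. A sketch of that validation, assuming one-hot rows are recognized by width and row sums (validate_targets_sketch is illustrative; the real checks may be stricter):

import torch


def validate_targets_sketch(y, one_hot_labels):
    # Mirror the three ValueError cases asserted above.
    y = torch.as_tensor(y)
    if one_hot_labels:
        if y.ndim != 2:
            raise ValueError('`one_hot_labels=True`, but targets are 1D.')
        if y.shape[1] < 2 or not torch.all(y.sum(dim=1) == 1):
            raise ValueError('2D targets are not one-hot encoded.')
    elif y.ndim == 2:
        raise ValueError('2D targets given, but `one_hot_labels=False`.')
    return y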
@@ -87,14 +197,14 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_func_initializer(self):
+    def test_prototypes1d_func_initializer(self):
         def my_initializer(*args, **kwargs):
             return torch.full((2, 99), 99), torch.tensor([0, 1])

-        p1 = prototypes.AddPrototypes1D(input_dim=99,
+        p1 = prototypes.Prototypes1D(input_dim=99,
                                      nclasses=2,
                                      prototypes_per_class=1,
                                      prototype_initializer=my_initializer)
         protos = p1.prototypes
         actual = protos.detach().numpy()
         desired = 99 * torch.ones(2, 99)
@@ -103,8 +213,8 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_forward(self):
-        p1 = prototypes.AddPrototypes1D(data=[self.x, self.y])
+    def test_prototypes1d_forward(self):
+        p1 = prototypes.Prototypes1D(data=[self.x, self.y])
         protos, _ = p1()
         actual = protos.detach().numpy()
         desired = torch.ones(2, 3)
@@ -113,6 +223,16 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_prototypes1d_dist_validate(self):
+        p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
+        with self.assertWarns(UserWarning):
+            _ = p1._validate_prototype_distribution()
+
+    def test_prototypes1d_validate_extra_repr_not_empty(self):
+        p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
+        rep = p1.extra_repr()
+        self.assertNotEqual(rep, '')
+
     def tearDown(self):
         del self.x, self.y, self.gen
         _ = torch.seed()
@@ -123,7 +243,19 @@ class TestLosses(unittest.TestCase):
         pass

     def test_glvqloss_init(self):
-        _ = losses.GLVQLoss()
+        _ = losses.GLVQLoss(0, 'swish_beta', beta=20)
+
+    def test_glvqloss_forward(self):
+        criterion = losses.GLVQLoss(margin=0,
+                                    squashing='sigmoid_beta',
+                                    beta=100)
+        d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
+        labels = torch.tensor([0, 1])
+        targets = torch.ones(100)
+        outputs = [d, labels]
+        loss = criterion(outputs, targets)
+        loss_value = loss.item()
+        self.assertAlmostEqual(loss_value, 0.0)

     def tearDown(self):
         pass
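test_glvqloss_forward ties the module to the functional cost: with margin=0 and beta=100, the sigmoid squashing of a per-sample cost of -1 is numerically zero, so the averaged loss comes out almost exactly 0.0. A self-contained sketch of such a module; GLVQLossSketch is illustrative, and the real GLVQLoss resolves the squashing function from its name instead of hard-coding a sigmoid:

import torch


class GLVQLossSketch(torch.nn.Module):
    def __init__(self, margin=0, beta=10):
        super().__init__()
        self.margin = margin
        self.beta = beta

    def forward(self, outputs, targets):
        distances, plabels = outputs
        # Per-sample GLVQ cost, as in the functional tests above.
        matching = targets.unsqueeze(1) == plabels.unsqueeze(0)
        inf = torch.full_like(distances, float('inf'))
        d_plus = torch.where(matching, distances, inf).min(dim=1).values
        d_minus = torch.where(~matching, distances, inf).min(dim=1).values
        mu = (d_plus - d_minus) / (d_plus + d_minus)  # in [-1, 1]
        # sigmoid_beta squashing: sigmoid(100 * (-1 + 0)) is ~0, which
        # gives the asserted loss of almost exactly 0.0.
        return torch.sigmoid(self.beta * (mu + self.margin)).mean()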
10 tox.ini
@@ -4,12 +4,12 @@
 # and then run "tox" from this directory.

 [tox]
-envlist = py36
+envlist = py36,py37,py38

 [testenv]
 deps =
-    numpy
-    unittest-xml-reporting
+    pytest
+    coverage
 commands =
-    python -m xmlrunner -o reports
+    pip install -e .
+    coverage run -m pytest