Compare commits
v0.1.1-dev ... v0.1.1-rc0
36 commits

SHA1:
0cfbc0473b
cf0659d881
d17b9a3346
532f63b1de
c11a3860df
dab91e471a
a167565857
e063625486
89eb5358a0
5c59515128
7eb7a6b194
5811c4b9f9
7b1887d56e
63a25e7a38
a0f20a40f6
88cbe0a126
a3548e0ddd
3cfbc49254
2b82830590
553b1e1a65
a9d2855323
cf7d7b5d9d
a22c752342
4158586cb9
f80d9648c3
e54bf07030
8c629c0cb1
8f3a43f62a
955661af95
c54d14c55e
6090aad176
1ec7bd261b
da3b0cc262
f640a22cf2
c843ace63d
242c9de3b6
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.1.1-dev0
+current_version = 0.1.1-rc0
 commit = True
 tag = True
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
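The `parse` pattern in this config is what lets bumpversion step the pre-release from `-dev0` to `-rc0`. A quick sketch (plain `re`, pattern copied from the config above) of how it decomposes the new version string:

```python
import re

# Pattern restated from the bumpversion config above.
PARSE = (r'(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'
         r'(\-(?P<release>[a-z]+)(?P<build>\d+))?')

match = re.fullmatch(PARSE, '0.1.1-rc0')
print(match.groupdict())
# {'major': '0', 'minor': '1', 'patch': '1', 'release': 'rc', 'build': '0'}
```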
.github/workflows/pythonapp.yml (vendored, 3 changes)
@@ -24,6 +24,9 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install .
+    - name: Install extras
+      run: |
+        pip install -r requirements.txt
     - name: Lint with flake8
       run: |
         pip install flake8
@@ -8,12 +8,11 @@ cache:

 install:
 - pip install . --progress-bar off
-- pip install codecov
-- pip install pytest
+- pip install -r requirements.txt

 script:
 - coverage run -m pytest

 # Push the results to codecov
 after_success:
-- codecov
+- bash <(curl -s https://codecov.io/bash)
README.md (11 changes)
@@ -4,11 +4,12 @@ ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
 prototype-based machine learning algorithms.

-[](https://travis-ci.org/si-cim/prototorch)
-[](https://badge.fury.io/gh/si-cim%2Fprototorch)
-[](https://badge.fury.io/py/prototorch)
+[](https://github.com/si-cim/prototorch/releases)
+[](https://pypi.org/project/prototorch/)
 [](https://codecov.io/gh/si-cim/prototorch)
 [](https://pepy.tech/project/prototorch)
 [](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&utm_medium=referral&utm_content=si-cim/prototorch&utm_campaign=Badge_Grade)
 [](https://github.com/si-cim/prototorch/blob/master/LICENSE)

 ## Description
@@ -27,12 +28,12 @@ provided by PyTorch.
 ## Installation

 ProtoTorch can be installed using `pip`.
-```
+```bash
 pip install prototorch
 ```

 To install the bleeding-edge features and improvements:
-```
+```bash
 git clone https://github.com/si-cim/prototorch.git
 git checkout dev
 cd prototorch
RELEASE.md (10 changes)
@@ -1,3 +1,11 @@
-# Release 0.1.0-dev0
+# ProtoTorch Releases
+
+## Release 0.1.1-dev0
+
+### Includes
+- Minor bugfixes.
+- 100% line coverage.
+
+## Release 0.1.0-dev0

 Initial public release of ProtoTorch.
@@ -1,4 +1,4 @@
-"""ProtoTorch GLVQ example using 2D Iris data"""
+"""ProtoTorch GLVQ example using 2D Iris data."""

 import numpy as np
 import torch
@@ -8,7 +8,7 @@ from sklearn.preprocessing import StandardScaler
 from prototorch.functions.distances import euclidean_distance
 from prototorch.modules.losses import GLVQLoss
-from prototorch.modules.prototypes import AddPrototypes1D
+from prototorch.modules.prototypes import Prototypes1D

 # Prepare and preprocess the data
 scaler = StandardScaler()
@@ -21,8 +21,9 @@ x_train = scaler.transform(x_train)
 # Define the GLVQ model
 class Model(torch.nn.Module):
     def __init__(self, **kwargs):
+        """GLVQ model."""
         super().__init__()
-        self.p1 = AddPrototypes1D(input_dim=2,
+        self.p1 = Prototypes1D(input_dim=2,
                                prototypes_per_class=1,
                                nclasses=3,
                                prototype_initializer='zeros')
@@ -41,13 +42,17 @@ model = Model()
 optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
 criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)

+x_in = torch.Tensor(x_train)
+y_in = torch.Tensor(y_train)
+
 # Training loop
-fig = plt.figure('Prototype Visualization')
+title = 'Prototype Visualization'
+fig = plt.figure(title)
 for epoch in range(70):
-    # Compute loss.
-    distances, plabels = model(torch.tensor(x_train))
-    loss = criterion([distances, plabels], torch.tensor(y_train))
-    print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():02.02f}')
+    # Compute loss
+    dis, plabels = model(x_in)
+    loss = criterion([dis, plabels], y_in)
+    print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')

     # Take a gradient descent step
     optimizer.zero_grad()
@@ -60,6 +65,9 @@ for epoch in range(70):
     # Visualize the data and the prototypes
     ax = fig.gca()
     ax.cla()
+    ax.set_title(title)
+    ax.set_xlabel('Data dimension 1')
+    ax.set_ylabel('Data dimension 2')
     cmap = 'viridis'
     ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
     ax.scatter(protos[:, 0],
@@ -71,28 +79,17 @@ for epoch in range(70):
                s=50)

     # Paint decision regions
-    border = 1
-    resolution = 50
     x = np.vstack((x_train, protos))
-    x_min, x_max = x[:, 0].min(), x[:, 0].max()
-    y_min, y_max = x[:, 1].min(), x[:, 1].max()
-    x_min, x_max = x_min - border, x_max + border
-    y_min, y_max = y_min - border, y_max + border
-    try:
-        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1.0 / resolution),
-                             np.arange(y_min, y_max, 1.0 / resolution))
-    except ValueError as ve:
-        print(ve)
-        raise ValueError(f'x_min: {x_min}, x_max: {x_max}. '
-                         f'x_min - x_max is {x_max - x_min}.')
-    except MemoryError as me:
-        print(me)
-        raise ValueError('Too many points. ' 'Try reducing the resolution.')
+    x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
+    y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
+    xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
+                         np.arange(y_min, y_max, 1 / 50))
     mesh_input = np.c_[xx.ravel(), yy.ravel()]

-    torch_input = torch.from_numpy(mesh_input)
+    torch_input = torch.Tensor(mesh_input)
     d = model(torch_input)[0]
-    y_pred = np.argmin(d.detach().numpy(), axis=1)
+    y_pred = np.argmin(d.detach().numpy(),
+                       axis=1)  # assume one prototype per class
     y_pred = y_pred.reshape(xx.shape)

     # Plot voronoi regions
@@ -100,4 +97,5 @@ for epoch in range(70):
     ax.set_xlim(left=x_min + 0, right=x_max - 0)
     ax.set_ylim(bottom=y_min + 0, top=y_max - 0)

+    plt.pause(0.1)
@@ -1 +1,11 @@
-__version__ = '0.1.1-dev0'
+"""ProtoTorch package."""
+
+__version__ = '0.1.1-rc0'
+
+from prototorch import datasets, functions, modules
+
+__all__ = [
+    'datasets',
+    'functions',
+    'modules',
+]
prototorch/datasets/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
"""ProtoTorch datasets."""

from .tecator import Tecator

__all__ = [
    'Tecator',
]
prototorch/datasets/abstract.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""ProtoTorch abstract dataset classes.

Based on `torchvision.VisionDataset` and `torchvision.MNIST`

For the original code, see:
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
"""

import os

import torch


class Dataset(torch.utils.data.Dataset):
    """Abstract dataset class to be inherited."""
    _repr_indent = 2

    def __init__(self, root):
        if isinstance(root, torch._six.string_classes):
            root = os.path.expanduser(root)
        self.root = root

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError


class ProtoDataset(Dataset):
    """Abstract dataset class to be inherited."""
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, download=True, verbose=True):
        super().__init__(root)
        self.train = train  # training set or test set
        self.verbose = verbose

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found. '
                               'You can use download=True to download it')

        data_file = self.training_file if self.train else self.test_file

        self.data, self.targets = torch.load(
            os.path.join(self.processed_folder, data_file))

    @property
    def raw_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        return os.path.join(self.root, self.__class__.__name__, 'processed')

    @property
    def class_to_idx(self):
        return {_class: i for i, _class in enumerate(self.classes)}

    def _check_exists(self):
        return (os.path.exists(
            os.path.join(self.processed_folder, self.training_file))
                and os.path.exists(
                    os.path.join(self.processed_folder, self.test_file)))

    def __repr__(self):
        head = 'Dataset ' + self.__class__.__name__
        body = ['Number of datapoints: {}'.format(self.__len__())]
        if self.root is not None:
            body.append('Root location: {}'.format(self.root))
        body += self.extra_repr().splitlines()
        lines = [head] + [' ' * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def extra_repr(self):
        return f"Split: {'Train' if self.train is True else 'Test'}"

    def __len__(self):
        return len(self.data)

    def download(self):
        raise NotImplementedError
prototorch/datasets/tecator.py (new file, 102 lines)
@@ -0,0 +1,102 @@
"""Tecator dataset for classification.

URL:
    http://lib.stat.cmu.edu/datasets/tecator

LICENCE / TERMS / COPYRIGHT:
    This is the Tecator data set: The task is to predict the fat content
    of a meat sample on the basis of its near infrared absorbance spectrum.
    -------------------------------------------------------------------------
    1. Statement of permission from Tecator (the original data source)

    These data are recorded on a Tecator Infratec Food and Feed Analyzer
    working in the wavelength range 850 - 1050 nm by the Near Infrared
    Transmission (NIT) principle. Each sample contains finely chopped pure
    meat with different moisture, fat and protein contents.

    If results from these data are used in a publication we want you to
    mention the instrument and company name (Tecator) in the publication.
    In addition, please send a preprint of your article to

        Karin Thente, Tecator AB,
        Box 70, S-263 21 Hoganas, Sweden

    The data are available in the public domain with no responsability from
    the original data source. The data can be redistributed as long as this
    permission note is attached.

    For more information about the instrument - call Perstorp Analytical's
    representative in your area.

Description:
    For each meat sample the data consists of a 100 channel spectrum of
    absorbances and the contents of moisture (water), fat and protein.
    The absorbance is -log10 of the transmittance
    measured by the spectrometer. The three contents, measured in percent,
    are determined by analytic chemistry.
"""

import os

import numpy as np
import torch
from torchvision.datasets.utils import download_file_from_google_drive

from prototorch.datasets.abstract import ProtoDataset


class Tecator(ProtoDataset):
    """Tecator dataset for classification."""
    resources = [
        ('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
         'ba5607c580d0f91bb27dc29d13c2f8df'),
    ]  # (google_storage_id, md5hash)
    classes = ['0 - low_fat', '1 - high_fat']

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        return img, target

    def download(self):
        """Download the data if it doesn't already exist."""
        if self._check_exists():
            return

        if self.verbose:
            print('Making directories...')
        os.makedirs(self.raw_folder, exist_ok=True)
        os.makedirs(self.processed_folder, exist_ok=True)

        if self.verbose:
            print('Downloading...')
        for fileid, md5 in self.resources:
            filename = 'tecator.npz'
            download_file_from_google_drive(fileid,
                                            root=self.raw_folder,
                                            filename=filename,
                                            md5=md5)

        if self.verbose:
            print('Processing...')
        with np.load(os.path.join(self.raw_folder, 'tecator.npz'),
                     allow_pickle=False) as f:
            x_train, y_train = f['x_train'], f['y_train']
            x_test, y_test = f['x_test'], f['y_test']
            training_set = [
                torch.tensor(x_train, dtype=torch.float32),
                torch.tensor(y_train),
            ]
            test_set = [
                torch.tensor(x_test, dtype=torch.float32),
                torch.tensor(y_test),
            ]

        with open(os.path.join(self.processed_folder, self.training_file),
                  'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.processed_folder, self.test_file),
                  'wb') as f:
            torch.save(test_set, f)

        if self.verbose:
            print('Done!')
@@ -0,0 +1,12 @@
"""ProtoTorch functions."""

from .activations import identity, sigmoid_beta, swish_beta
from .competitions import knnc, wtac

__all__ = [
    'identity',
    'sigmoid_beta',
    'swish_beta',
    'knnc',
    'wtac',
]
@@ -8,47 +8,57 @@ ACTIVATIONS = dict()

 # def register_activation(scriptf):
 #     ACTIVATIONS[scriptf.name] = scriptf
 #     return scriptf
-def register_activation(f):
-    ACTIVATIONS[f.__name__] = f
-    return f
+def register_activation(function):
+    """Add the activation function to the registry."""
+    ACTIVATIONS[function.__name__] = function
+    return function


 @register_activation
 # @torch.jit.script
-def identity(input, beta=torch.tensor([0])):
-    """:math:`f(x) = x`"""
-    return input
+def identity(x, beta=torch.tensor(0)):
+    """Identity activation function.
+
+    Definition:
+        :math:`f(x) = x`
+    """
+    return x


 @register_activation
 # @torch.jit.script
-def sigmoid_beta(input, beta=torch.tensor([10])):
-    """:math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
+def sigmoid_beta(x, beta=torch.tensor(10)):
+    r"""Sigmoid activation function with scaling.
+
+    Definition:
+        :math:`f(x) = \frac{1}{1 + e^{-\beta x}}`

     Keyword Arguments:
-        beta (float): Parameter :math:`\\beta`
+        beta (`torch.tensor`): Scaling parameter :math:`\beta`
     """
-    out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input))
+    out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x))
     return out


 @register_activation
 # @torch.jit.script
-def swish_beta(input, beta=torch.tensor([10])):
-    """:math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
+def swish_beta(x, beta=torch.tensor(10)):
+    r"""Swish activation function with scaling.
+
+    Definition:
+        :math:`f(x) = \frac{x}{1 + e^{-\beta x}}`

     Keyword Arguments:
-        beta (float): Parameter :math:`\\beta`
+        beta (`torch.tensor`): Scaling parameter :math:`\beta`
     """
-    out = input * sigmoid_beta(input, beta=beta)
+    out = x * sigmoid_beta(x, beta=beta)
     return out


 def get_activation(funcname):
     """Deserialize the activation function."""
     if callable(funcname):
         return funcname
     else:
         if funcname in ACTIVATIONS:
             return ACTIVATIONS.get(funcname)
         else:
             raise NameError(f'Activation {funcname} was not found.')
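A short usage sketch of the registry pattern above; `get_activation` resolves names and passes callables straight through, so both spellings return the very same function object:

```python
import torch

from prototorch.functions.activations import get_activation, sigmoid_beta

x = torch.linspace(-1.0, 1.0, 5)
print(sigmoid_beta(x, beta=torch.tensor(10)))

f = get_activation('sigmoid_beta')  # looked up in ACTIVATIONS
g = get_activation(sigmoid_beta)    # a callable is returned as-is
assert f is g is sigmoid_beta
```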
@@ -3,6 +3,34 @@
 import torch


+# @torch.jit.script
+def stratified_min(distances, labels):
+    clabels = torch.unique(labels, dim=0)
+    nclasses = clabels.size()[0]
+    if distances.size()[1] == nclasses:
+        # skip if only one prototype per class
+        return distances
+    batch_size = distances.size()[0]
+    winning_distances = torch.zeros(nclasses, batch_size)
+    inf = torch.full_like(distances.T, fill_value=float('inf'))
+    # distances_to_wpluses = torch.where(matcher, distances, inf)
+    for i, cl in enumerate(clabels):
+        # cdists = distances.T[labels == cl]
+        matcher = torch.eq(labels.unsqueeze(dim=1), cl)
+        if labels.ndim == 2:
+            # if the labels are one-hot vectors
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        cdists = torch.where(matcher, distances.T, inf).T
+        winning_distances[i] = torch.min(cdists, dim=1,
+                                         keepdim=True).values.squeeze()
+    if labels.ndim == 2:
+        # Transpose to return with `batch_size` first and
+        # reverse the columns to fix the ordering of the classes
+        return torch.flip(winning_distances.T, dims=(1, ))
+
+    return winning_distances.T  # return with `batch_size` first
+
+
 # @torch.jit.script
 def wtac(distances, labels):
     winning_indices = torch.min(distances, dim=1).indices
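The new `stratified_min` collapses a `(batch_size, total_prototypes)` distance matrix to one winning distance per class; with exactly one prototype per class it is a no-op. The values below are taken from `test_stratified_min` in the test suite:

```python
import torch

from prototorch.functions.competitions import stratified_min

# Two prototypes for class 0, one each for classes 1 and 2.
d = torch.tensor([[1., 0., 2., 3.],
                  [9., 8., 0., 1.]])
labels = torch.tensor([0, 0, 1, 2])
print(stratified_min(d, labels))
# tensor([[0., 2., 3.],
#         [8., 0., 1.]])
```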
@@ -28,7 +28,7 @@ def euclidean_distance(x, y):

 def lpnorm_distance(x, y, p):
-    """Compute :math:`{\\langle x, y \\rangle}_p`.
+    r"""Compute :math:`{\langle x, y \rangle}_p`.

     Expected dimension of x is 2.
     Expected dimension of y is 2.
@@ -38,9 +38,9 @@ def lpnorm_distance(x, y, p):

 def omega_distance(x, y, omega):
-    """Omega distance.
+    r"""Omega distance.

-    Compute :math:`{\\langle \\Omega x, \\Omega y \\rangle}_p`
+    Compute :math:`{\langle \Omega x, \Omega y \rangle}_p`

     Expected dimension of x is 2.
     Expected dimension of y is 2.
@@ -53,9 +53,9 @@ def omega_distance(x, y, omega):

 def lomega_distance(x, y, omegas):
-    """Localized Omega distance.
+    r"""Localized Omega distance.

-    Compute :math:`{\\langle \\Omega_k x, \\Omega_k y_k \\rangle}_p`
+    Compute :math:`{\langle \Omega_k x, \Omega_k y_k \rangle}_p`

     Expected dimension of x is 2.
     Expected dimension of y is 2.
@@ -69,3 +69,7 @@ def lomega_distance(x, y, omegas):
     distances = torch.sum(differences_squared, dim=2)
     distances = distances.permute(1, 0)
     return distances
+
+
+# Aliases
+sed = squared_euclidean_distance
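A sketch of how these pairwise distance functions are called, assuming the `(batch, prototypes)` output shape that the GLVQ example relies on when it takes `argmin` over axis 1; the exact relationship between `euclidean_distance` and the new `sed` alias (square root vs. squared) is an assumption, not spelled out in this hunk:

```python
import torch

from prototorch.functions.distances import (euclidean_distance,
                                            lpnorm_distance, sed)

x = torch.randn(5, 3)  # five data points
y = torch.randn(4, 3)  # four prototypes

d = euclidean_distance(x, y)
print(d.shape)  # assumed pairwise: torch.Size([5, 4])

# The new alias: sed is squared_euclidean_distance, presumably d**2
# up to floating-point error.
print(torch.allclose(sed(x, y), d**2, atol=1e-5))

d1 = lpnorm_distance(x, y, p=1)  # other p-norms via the same interface
```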
@@ -7,87 +7,97 @@ import torch
 INITIALIZERS = dict()


-def register_initializer(func):
-    INITIALIZERS[func.__name__] = func
-    return func
+def register_initializer(function):
+    """Add the initializer to the registry."""
+    INITIALIZERS[function.__name__] = function
+    return function


-def labels_from(distribution):
+def labels_from(distribution, one_hot=True):
     """Takes a distribution tensor and returns a labels tensor."""
     nclasses = distribution.shape[0]
     llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
     # labels = [l for cl in llist for l in cl]  # flatten the list of lists
-    labels = list(chain(*llist))  # flatten using itertools.chain
-    return torch.tensor(labels, requires_grad=False)
+    flat_llist = list(chain(*llist))  # flatten label list with itertools.chain
+    plabels = torch.tensor(flat_llist, requires_grad=False)
+    if one_hot:
+        return torch.eye(nclasses)[plabels]
+    return plabels


 @register_initializer
-def ones(x_train, y_train, prototype_distribution):
+def ones(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.ones(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def zeros(x_train, y_train, prototype_distribution):
+def zeros(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.zeros(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def rand(x_train, y_train, prototype_distribution):
+def rand(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.rand(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def randn(x_train, y_train, prototype_distribution):
+def randn(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     protos = torch.randn(nprotos, *x_train.shape[1:])
-    plabels = labels_from(prototype_distribution)
+    plabels = labels_from(prototype_distribution, one_hot)
     return protos, plabels


 @register_initializer
-def stratified_mean(x_train, y_train, prototype_distribution):
+def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
-    plabels = labels_from(prototype_distribution)
-    for i, l in enumerate(plabels):
-        xl = x_train[y_train == l]
+    plabels = labels_from(prototype_distribution, one_hot)
+    for i, label in enumerate(plabels):
+        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
+        if one_hot:
+            nclasses = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        xl = x_train[matcher]
         mean_xl = torch.mean(xl, dim=0)
         protos[i] = mean_xl
+    plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels


 @register_initializer
-def stratified_random(x_train, y_train, prototype_distribution):
-    gen = torch.manual_seed(torch.initial_seed())
+def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
     nprotos = torch.sum(prototype_distribution)
     pdim = x_train.shape[1]
     protos = torch.empty(nprotos, pdim)
-    plabels = labels_from(prototype_distribution)
-    for i, l in enumerate(plabels):
-        xl = x_train[y_train == l]
-        rand_index = torch.zeros(1).long().random_(0,
-                                                   xl.shape[1] - 1,
-                                                   generator=gen)
+    plabels = labels_from(prototype_distribution, one_hot)
+    for i, label in enumerate(plabels):
+        matcher = torch.eq(label.unsqueeze(dim=0), y_train)
+        if one_hot:
+            nclasses = y_train.size()[1]
+            matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
+        xl = x_train[matcher]
+        rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
         random_xl = xl[rand_index]
         protos[i] = random_xl
+    plabels = labels_from(prototype_distribution, one_hot=one_hot)
     return protos, plabels


 def get_initializer(funcname):
     """Deserialize the initializer."""
     if callable(funcname):
         return funcname
     else:
         if funcname in INITIALIZERS:
             return INITIALIZERS.get(funcname)
         else:
             raise NameError(f'Initializer {funcname} was not found.')
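The effect of the new `one_hot` switch, as a small worked sketch (data values made up for illustration; with `one_hot=False` the functions behave as before):

```python
import torch

from prototorch.functions.initializers import labels_from, stratified_mean

pdist = torch.tensor([2, 1])  # two prototypes for class 0, one for class 1
print(labels_from(pdist, one_hot=False))  # tensor([0, 0, 1])
print(labels_from(pdist, one_hot=True))
# tensor([[1., 0.],
#         [1., 0.],
#         [0., 1.]])

x = torch.tensor([[0., 0.], [2., 2.], [10., 10.]])
y = torch.tensor([0, 0, 1])
protos, plabels = stratified_mean(x, y, pdist, one_hot=False)
print(protos)  # class means: [[1., 1.], [1., 1.], [10., 10.]]
```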
@@ -3,20 +3,24 @@
 import torch


-def glvq_loss(distances, target_labels, prototype_labels):
-    """GLVQ loss function with support for one-hot labels."""
-    matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels)
-    if prototype_labels.ndim == 2:
+def _get_dp_dm(distances, targets, plabels):
+    matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
+    if plabels.ndim == 2:
         # if the labels are one-hot vectors
-        nclasses = target_labels.size()[1]
+        nclasses = targets.size()[1]
         matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
     not_matcher = torch.bitwise_not(matcher)

     inf = torch.full_like(distances, fill_value=float('inf'))
-    distances_to_wpluses = torch.where(matcher, distances, inf)
-    distances_to_wminuses = torch.where(not_matcher, distances, inf)
-    dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
-    dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
+    d_matching = torch.where(matcher, distances, inf)
+    d_unmatching = torch.where(not_matcher, distances, inf)
+    dp = torch.min(d_matching, dim=1, keepdim=True).values
+    dm = torch.min(d_unmatching, dim=1, keepdim=True).values
+    return dp, dm

-    mu = (dpluses - dminuses) / (dpluses + dminuses)

+def glvq_loss(distances, target_labels, prototype_labels):
+    """GLVQ loss function with support for one-hot labels."""
+    dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
+    mu = (dp - dm) / (dp + dm)
     return mu
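A worked sketch of the classifier measure `mu = (d+ - d-) / (d+ + d-)` computed by `glvq_loss`: with the distance to the matching prototype fixed at 1 and to the non-matching one at 0, every sample lies on the wrong side of the decision boundary and `mu = +1`; well-classified samples yield negative values, which is why the one-hot test further down expects a summed loss of -100:

```python
import torch

from prototorch.functions.losses import glvq_loss

# Two prototypes, one per class; every sample belongs to class 0.
distances = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
targets = torch.zeros(100)
plabels = torch.tensor([0, 1])

mu = glvq_loss(distances, targets, plabels)
# d+ = 1 (matching prototype), d- = 0, so mu = (1 - 0) / (1 + 0) = +1
# for every sample, i.e. all of them are misclassified.
print(torch.sum(mu))  # tensor(100.)
```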
@@ -0,0 +1,7 @@
"""ProtoTorch modules."""

from .prototypes import Prototypes1D

__all__ = [
    'Prototypes1D',
]
@@ -7,7 +7,6 @@ from prototorch.functions.losses import glvq_loss


 class GLVQLoss(torch.nn.Module):
-    """GLVQ Loss."""
     def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
         super().__init__(**kwargs)
         self.margin = margin
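A usage sketch of the module wrapper, mirroring the call pattern from the GLVQ Iris example earlier in this diff (a `[distances, prototype_labels]` pair plus the targets):

```python
import torch

from prototorch.modules.losses import GLVQLoss

criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)

distances = torch.rand(8, 3, requires_grad=True)  # stand-in for model output
plabels = torch.tensor([0, 1, 2])
targets = torch.tensor([0., 1., 2., 0., 1., 2., 0., 1.])

loss = criterion([distances, plabels], targets)
loss.backward()
```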
@@ -1,57 +1,165 @@
 """ProtoTorch prototype modules."""

+import warnings
+
 import numpy as np
 import torch

 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import sed
 from prototorch.functions.initializers import get_initializer


-class AddPrototypes1D(torch.nn.Module):
+class _Prototypes(torch.nn.Module):
+    """Abstract prototypes class."""
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _validate_prototype_distribution(self):
+        if 0 in self.prototype_distribution:
+            warnings.warn('Are you sure about the `0` in '
+                          '`prototype_distribution`?')
+
+    def extra_repr(self):
+        return f'prototypes.shape: {tuple(self.prototypes.shape)}'
+
+    def forward(self):
+        return self.prototypes, self.prototype_labels
+
+
+class Prototypes1D(_Prototypes):
+    r"""Create a learnable set of one-dimensional prototypes.
+
+    TODO Complete this doc-string
+
+    Kwargs:
+        prototypes_per_class: number of prototypes to use per class.
+            Default: ``1``
+        prototype_initializer: prototype initializer.
+            Default: ``'ones'``
+        prototype_distribution: prototype distribution vector.
+            Default: ``None``
+        input_dim: dimension of the incoming data.
+        nclasses: number of classes.
+        data: If set to ``None``, data-dependent initializers will be ignored.
+            Default: ``None``
+
+    Shape:
+        - Input: :math:`(N, H_{in})`
+          where :math:`H_{in} = \text{input_dim}`.
+        - Output: :math:`(N, H_{out})`
+          where :math:`H_{out} = \text{total_prototypes}`.
+
+    Attributes:
+        prototypes: the learnable weights of the module of shape
+            :math:`(\text{total_prototypes}, \text{prototype_dimension})`.
+        prototype_labels: the non-learnable labels of the prototypes.
+
+    Examples::
+
+        >>> p = Prototypes1D(input_dim=20, nclasses=10)
+        >>> input = torch.randn(128, 20)
+        >>> output = m(input)
+        >>> print(output.size())
+        torch.Size([20, 10])
+    """
     def __init__(self,
                  prototypes_per_class=1,
-                 prototype_distribution=None,
                  prototype_initializer='ones',
+                 prototype_distribution=None,
+                 data=None,
+                 dtype=torch.float32,
+                 one_hot_labels=False,
                  **kwargs):
+
+        # Convert tensors to python lists before processing
+        if prototype_distribution is not None:
+            if not isinstance(prototype_distribution, list):
+                prototype_distribution = prototype_distribution.tolist()
+
         if data is None:
             if 'input_dim' not in kwargs:
                 raise NameError('`input_dim` required if '
                                 'no `data` is provided.')
-            if prototype_distribution is not None:
-                nclasses = sum(prototype_distribution)
+            if prototype_distribution:
+                kwargs_nclasses = sum(prototype_distribution)
             else:
                 if 'nclasses' not in kwargs:
                     raise NameError('`prototype_distribution` required if '
                                     'both `data` and `nclasses` are not '
                                     'provided.')
-                nclasses = kwargs.pop('nclasses')
+                kwargs_nclasses = kwargs.pop('nclasses')
             input_dim = kwargs.pop('input_dim')
-            # input_shape = (input_dim, )
-            x_train = torch.rand(nclasses, input_dim)
-            y_train = torch.arange(nclasses)
+            if prototype_initializer in [
+                    'stratified_mean', 'stratified_random'
+            ]:
+                warnings.warn(
+                    f'`prototype_initializer`: `{prototype_initializer}` '
+                    'requires `data`, but `data` is not provided. '
+                    'Using randomly generated data instead.')
+            x_train = torch.rand(kwargs_nclasses, input_dim)
+            y_train = torch.arange(kwargs_nclasses)
+            if one_hot_labels:
+                y_train = torch.eye(kwargs_nclasses)[y_train]
+            data = [x_train, y_train]

-        else:
-            x_train, y_train = data
-            x_train = torch.as_tensor(x_train)
-            y_train = torch.as_tensor(y_train)
+        x_train, y_train = data
+        x_train = torch.as_tensor(x_train).type(dtype)
+        y_train = torch.as_tensor(y_train).type(torch.int)
+        nclasses = torch.unique(y_train, dim=-1).shape[-1]
+
+        if nclasses == 1:
+            warnings.warn('Are you sure about having one class only?')
+
+        if x_train.ndim != 2:
+            raise ValueError('`data[0].ndim != 2`.')
+
+        if y_train.ndim == 2:
+            if y_train.shape[1] == 1 and one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `True` '
+                                 'but target labels are not one-hot-encoded.')
+            if y_train.shape[1] != 1 and not one_hot_labels:
+                raise ValueError('`one_hot_labels` is set to `False` '
+                                 'but target labels in `data` '
+                                 'are one-hot-encoded.')
+        if y_train.ndim == 1 and one_hot_labels:
+            raise ValueError('`one_hot_labels` is set to `True` '
+                             'but target labels are not one-hot-encoded.')
+
+        # Verify input dimension if `input_dim` is provided
+        if 'input_dim' in kwargs:
+            input_dim = kwargs.pop('input_dim')
+            if input_dim != x_train.shape[1]:
+                raise ValueError(f'Provided `input_dim`={input_dim} does '
+                                 'not match data dimension '
+                                 f'`data[0].shape[1]`={x_train.shape[1]}')
+
+        # Verify the number of classes if `nclasses` is provided
+        if 'nclasses' in kwargs:
+            kwargs_nclasses = kwargs.pop('nclasses')
+            if kwargs_nclasses != nclasses:
+                raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
+                                 'not match data labels '
+                                 '`torch.unique(data[1]).shape[0]`'
+                                 f'={nclasses}')

         super().__init__(**kwargs)
-        self.prototypes_per_class = prototypes_per_class
-        with torch.no_grad():
-            if not prototype_distribution:
-                num_classes = torch.unique(y_train).shape[0]
-                self.prototype_distribution = torch.tensor(
-                    [self.prototypes_per_class] * num_classes)
-            else:
-                self.prototype_distribution = torch.tensor(
-                    prototype_distribution)
+
+        if not prototype_distribution:
+            prototype_distribution = [prototypes_per_class] * nclasses
+        with torch.no_grad():
+            self.prototype_distribution = torch.tensor(prototype_distribution)
+
+        self._validate_prototype_distribution()

         self.prototype_initializer = get_initializer(prototype_initializer)
         prototypes, prototype_labels = self.prototype_initializer(
             x_train,
             y_train,
-            prototype_distribution=self.prototype_distribution)
+            prototype_distribution=self.prototype_distribution,
+            one_hot=one_hot_labels,
+        )

         # Register module parameters
         self.prototypes = torch.nn.Parameter(prototypes)
         self.prototype_labels = prototype_labels
-
-    def forward(self):
-        return self.prototypes, self.prototype_labels
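A usage sketch of the refactored module with data-dependent initialization; the tensors are made up, but the signature and the `(prototypes, prototype_labels)` return pair follow the code above:

```python
import torch

from prototorch.modules import Prototypes1D

x = torch.tensor([[0., 0.], [1., 1.], [9., 9.], [10., 10.]])
y = torch.tensor([0, 0, 1, 1])

# Two prototypes per class, placed on the class means of `data`.
p1 = Prototypes1D(data=[x, y],
                  prototypes_per_class=2,
                  prototype_initializer='stratified_mean')
protos, plabels = p1()  # forward() takes no input
print(plabels)          # tensor([0, 0, 1, 1])
print(protos.detach())
# tensor([[0.5000, 0.5000],
#         [0.5000, 0.5000],
#         [9.5000, 9.5000],
#         [9.5000, 9.5000]])
```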
requirements.txt (new file, 5 lines)
@@ -0,0 +1,5 @@
matplotlib==3.1.2
pytest==5.3.4
requests==2.22.0
codecov==2.0.22
tqdm==4.44.1
setup.py (3 changes)
@@ -10,7 +10,7 @@ with open('README.md', 'r') as fh:
     long_description = fh.read()

 setup(name='prototorch',
-      version='0.1.1-dev0',
+      version='0.1.1-rc0',
      description='Highly extensible, GPU-supported '
      'Learning Vector Quantization (LVQ) toolbox '
      'built using PyTorch and its nn API.',
@@ -27,6 +27,7 @@ setup(name='prototorch',
           'numpy>=1.9.1',
       ],
       extras_require={
+          'datasets': ['requests'],
           'examples': [
               'sklearn',
               'matplotlib',
tests/test_datasets.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""ProtoTorch datasets test suite."""

import os
import shutil
import unittest

import torch

from prototorch.datasets import abstract, tecator


class TestAbstract(unittest.TestCase):
    def test_getitem(self):
        with self.assertRaises(NotImplementedError):
            abstract.Dataset('./artifacts')[0]

    def test_len(self):
        with self.assertRaises(NotImplementedError):
            len(abstract.Dataset('./artifacts'))


class TestProtoDataset(unittest.TestCase):
    def test_getitem(self):
        with self.assertRaises(NotImplementedError):
            abstract.ProtoDataset('./artifacts')[0]

    def test_download(self):
        with self.assertRaises(NotImplementedError):
            abstract.ProtoDataset('./artifacts').download()


class TestTecator(unittest.TestCase):
    def setUp(self):
        self.artifacts_dir = './artifacts/Tecator'
        self._remove_artifacts()

    def _remove_artifacts(self):
        if os.path.exists(self.artifacts_dir):
            shutil.rmtree(self.artifacts_dir)

    def test_download_false(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        self._remove_artifacts()
        with self.assertRaises(RuntimeError):
            _ = tecator.Tecator(rootdir, download=False)

    def test_download_caching(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        _ = tecator.Tecator(rootdir, download=True, verbose=False)
        _ = tecator.Tecator(rootdir, download=False, verbose=False)

    def test_repr(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        train = tecator.Tecator(rootdir, download=True, verbose=True)
        self.assertTrue('Split: Train' in train.__repr__())

    def test_download_train(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        train = tecator.Tecator(root=rootdir,
                                train=True,
                                download=True,
                                verbose=False)
        train = tecator.Tecator(root=rootdir, download=True, verbose=False)
        x_train, y_train = train.data, train.targets
        self.assertEqual(x_train.shape[0], 144)
        self.assertEqual(y_train.shape[0], 144)
        self.assertEqual(x_train.shape[1], 100)

    def test_download_test(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        test = tecator.Tecator(root=rootdir, train=False, verbose=False)
        x_test, y_test = test.data, test.targets
        self.assertEqual(x_test.shape[0], 71)
        self.assertEqual(y_test.shape[0], 71)
        self.assertEqual(x_test.shape[1], 100)

    def test_class_to_idx(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        test = tecator.Tecator(root=rootdir, train=False, verbose=False)
        _ = test.class_to_idx

    def test_getitem(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        test = tecator.Tecator(root=rootdir, train=False, verbose=False)
        x, y = test[0]
        self.assertEqual(x.shape[0], 100)
        self.assertIsInstance(y, int)

    def test_loadable_with_dataloader(self):
        rootdir = self.artifacts_dir.rpartition('/')[0]
        test = tecator.Tecator(root=rootdir, train=False, verbose=False)
        _ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)

    def tearDown(self):
        pass
@@ -85,6 +85,16 @@ class TestCompetitions(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_wtac_unequal_dist(self):
+        d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
+        labels = torch.tensor([0, 1, 1])
+        actual = competitions.wtac(d, labels)
+        desired = torch.tensor([0, 1])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
     def test_wtac_one_hot(self):
         d = torch.tensor([[1.99, 3.01], [3., 2.01]])
         labels = torch.tensor([[0, 1], [1, 0]])
@@ -95,6 +105,37 @@ class TestCompetitions(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_stratified_min(self):
+        d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
+        labels = torch.tensor([0, 0, 1, 2])
+        actual = competitions.stratified_min(d, labels)
+        desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_min_one_hot(self):
+        d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
+        labels = torch.tensor([0, 0, 1, 2])
+        labels = torch.eye(3)[labels]
+        actual = competitions.stratified_min(d, labels)
+        desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_min_simple(self):
+        d = torch.tensor([[0., 2., 3.], [8., 0, 1]])
+        labels = torch.tensor([0, 1, 2])
+        actual = competitions.stratified_min(d, labels)
+        desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
     def test_knnc_k1(self):
         d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
         labels = torch.tensor([0, 1, 2, 3])
@@ -341,7 +382,7 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_mean_equal1(self):
         pdist = torch.tensor([1, 1])
-        actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
+        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
         desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
                                                         desired,
@@ -350,8 +391,9 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_random_equal1(self):
         pdist = torch.tensor([1, 1])
-        actual, _ = initializers.stratified_random(self.x, self.y, pdist)
-        desired = torch.tensor([[0., -1., -2.], [2., 2., 2.]])
+        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
+                                                   False)
+        desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
                                                         desired,
                                                         decimal=5)
@@ -359,7 +401,7 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_mean_equal2(self):
         pdist = torch.tensor([2, 2])
-        actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
+        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
         desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
                                 [1., 1., 1.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
@@ -367,9 +409,20 @@ class TestInitializers(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_stratified_random_equal2(self):
+        pdist = torch.tensor([2, 2])
+        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
+                                                   False)
+        desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
+                                [0., 0., 0.]])
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
     def test_stratified_mean_unequal(self):
         pdist = torch.tensor([1, 3])
-        actual, _ = initializers.stratified_mean(self.x, self.y, pdist)
+        actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
         desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
                                 [1., 1., 1.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
@@ -379,14 +432,45 @@ class TestInitializers(unittest.TestCase):

     def test_stratified_random_unequal(self):
         pdist = torch.tensor([1, 3])
-        actual, _ = initializers.stratified_random(self.x, self.y, pdist)
-        desired = torch.tensor([[0., -1., -2.], [2., 2., 2.], [0., 0., 0.],
+        actual, _ = initializers.stratified_random(self.x, self.y, pdist,
+                                                   False)
+        desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
                                 [0., 0., 0.]])
         mismatch = np.testing.assert_array_almost_equal(actual,
                                                         desired,
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_stratified_mean_unequal_one_hot(self):
+        pdist = torch.tensor([1, 3])
+        y = torch.eye(2)[self.y]
+        desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
+                                 [1., 1., 1.]])
+        actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
+        desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
+        mismatch = np.testing.assert_array_almost_equal(actual1,
+                                                        desired1,
+                                                        decimal=5)
+        mismatch = np.testing.assert_array_almost_equal(actual2,
+                                                        desired2,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_stratified_random_unequal_one_hot(self):
+        pdist = torch.tensor([1, 3])
+        y = torch.eye(2)[self.y]
+        actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
+        desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
+                                 [0., 0., 0.]])
+        desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
+        mismatch = np.testing.assert_array_almost_equal(actual1,
+                                                        desired1,
+                                                        decimal=5)
+        mismatch = np.testing.assert_array_almost_equal(actual2,
+                                                        desired2,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
     def tearDown(self):
         del self.x, self.y, self.gen
         _ = torch.seed()
@@ -417,5 +501,17 @@ class TestLosses(unittest.TestCase):
         loss_value = torch.sum(batch_loss, dim=0)
         self.assertEqual(loss_value, -100)

+    def test_glvq_loss_one_hot_unequal(self):
+        dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
+        d = torch.stack(dlist, dim=1)
+        labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
+        wl = torch.tensor([1, 0])
+        targets = torch.stack([wl for _ in range(100)], dim=0)
+        batch_loss = losses.glvq_loss(distances=d,
+                                      target_labels=targets,
+                                      prototype_labels=labels)
+        loss_value = torch.sum(batch_loss, dim=0)
+        self.assertEqual(loss_value, -100)
+
     def tearDown(self):
         pass
@@ -5,7 +5,7 @@ import unittest
 import numpy as np
 import torch

-from prototorch.modules import prototypes, losses
+from prototorch.modules import losses, prototypes


 class TestPrototypes(unittest.TestCase):
@@ -16,16 +16,20 @@ class TestPrototypes(unittest.TestCase):
         self.y = torch.tensor([0, 0, 1, 1])
         self.gen = torch.manual_seed(42)

-    def test_addprototypes1d_init_without_input_dim(self):
+    def test_prototypes1d_init_without_input_dim(self):
         with self.assertRaises(NameError):
-            _ = prototypes.AddPrototypes1D(nclasses=1)
+            _ = prototypes.Prototypes1D(nclasses=2)

-    def test_addprototypes1d_init_without_nclasses(self):
+    def test_prototypes1d_init_without_nclasses(self):
         with self.assertRaises(NameError):
-            _ = prototypes.AddPrototypes1D(input_dim=1)
+            _ = prototypes.Prototypes1D(input_dim=1)

-    def test_addprototypes1d_init_without_pdist(self):
-        p1 = prototypes.AddPrototypes1D(input_dim=6,
+    def test_prototypes1d_init_with_nclasses_1(self):
+        with self.assertWarns(UserWarning):
+            _ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
+
+    def test_prototypes1d_init_without_pdist(self):
+        p1 = prototypes.Prototypes1D(input_dim=6,
                                      nclasses=2,
                                      prototypes_per_class=4,
                                      prototype_initializer='ones')
@@ -37,9 +41,9 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_init_without_data(self):
+    def test_prototypes1d_init_without_data(self):
         pdist = [2, 2]
-        p1 = prototypes.AddPrototypes1D(input_dim=3,
+        p1 = prototypes.Prototypes1D(input_dim=3,
                                      prototype_distribution=pdist,
                                      prototype_initializer='zeros')
         protos = p1.prototypes
@@ -50,21 +54,127 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    # def test_addprototypes1d_init_torch_pdist(self):
-    #     pdist = torch.tensor([2, 2])
-    #     p1 = prototypes.AddPrototypes1D(input_dim=3,
-    #                                     prototype_distribution=pdist,
-    #                                     prototype_initializer='zeros')
-    #     protos = p1.prototypes
-    #     actual = protos.detach().numpy()
-    #     desired = torch.zeros(4, 3)
-    #     mismatch = np.testing.assert_array_almost_equal(actual,
-    #                                                     desired,
-    #                                                     decimal=5)
-    #     self.assertIsNone(mismatch)
+    def test_prototypes1d_proto_init_without_data(self):
+        with self.assertWarns(UserWarning):
+            _ = prototypes.Prototypes1D(
+                input_dim=3,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=None)

-    def test_addprototypes1d_init_with_ppc(self):
-        p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
+    def test_prototypes1d_init_torch_pdist(self):
+        pdist = torch.tensor([2, 2])
+        p1 = prototypes.Prototypes1D(input_dim=3,
+                                     prototype_distribution=pdist,
+                                     prototype_initializer='zeros')
+        protos = p1.prototypes
+        actual = protos.detach().numpy()
+        desired = torch.zeros(4, 3)
+        mismatch = np.testing.assert_array_almost_equal(actual,
+                                                        desired,
+                                                        decimal=5)
+        self.assertIsNone(mismatch)
+
+    def test_prototypes1d_init_without_inputdim_with_data(self):
+        _ = prototypes.Prototypes1D(nclasses=2,
+                                    prototypes_per_class=1,
+                                    prototype_initializer='stratified_mean',
+                                    data=[[[1.], [0.]], [1, 0]])
+
+    def test_prototypes1d_init_with_int_data(self):
+        _ = prototypes.Prototypes1D(nclasses=2,
+                                    prototypes_per_class=1,
+                                    prototype_initializer='stratified_mean',
+                                    data=[[[1], [0]], [1, 0]])
+
+    def test_prototypes1d_init_one_hot_without_data(self):
+        _ = prototypes.Prototypes1D(input_dim=1,
+                                    nclasses=2,
+                                    prototypes_per_class=1,
+                                    prototype_initializer='stratified_mean',
+                                    data=None,
+                                    one_hot_labels=True)
+
+    def test_prototypes1d_init_one_hot_labels_false(self):
+        """Test if ValueError is raised when `one_hot_labels` is set to `False`
+        but the provided `data` has one-hot encoded labels.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=([[0.], [1.]], [[0, 1], [1, 0]]),
+                one_hot_labels=False)
+
+    def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
+        """Test if ValueError is raised when `one_hot_labels` is set to `True`
+        but the provided `data` does not contain one-hot encoded labels.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=([[0.], [1.]], [0, 1]),
+                one_hot_labels=True)
+
+    def test_prototypes1d_init_one_hot_labels_true(self):
+        """Test if ValueError is raised when `one_hot_labels` is set to `True`
+        but the provided `data` contains 2D targets but
+        does not contain one-hot encoded labels.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=([[0.], [1.]], [[0], [1]]),
+                one_hot_labels=True)
+
+    def test_prototypes1d_init_with_int_dtype(self):
+        with self.assertRaises(RuntimeError):
+            _ = prototypes.Prototypes1D(
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=[[[1], [0]], [1, 0]],
+                dtype=torch.int32)
+
+    def test_prototypes1d_inputndim_with_data(self):
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(input_dim=1,
+                                        nclasses=1,
+                                        prototypes_per_class=1,
+                                        data=[[1.], [1]])
+
+    def test_prototypes1d_inputdim_with_data(self):
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=2,
+                nclasses=2,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=[[[1.], [0.]], [1, 0]])
+
+    def test_prototypes1d_nclasses_with_data(self):
+        """Test ValueError raise if provided `nclasses` is not the same
+        as the one computed from the provided `data`.
+        """
+        with self.assertRaises(ValueError):
+            _ = prototypes.Prototypes1D(
+                input_dim=1,
+                nclasses=1,
+                prototypes_per_class=1,
+                prototype_initializer='stratified_mean',
+                data=[[[1.], [2.]], [1, 2]])
+
+    def test_prototypes1d_init_with_ppc(self):
+        p1 = prototypes.Prototypes1D(data=[self.x, self.y],
                                      prototypes_per_class=2,
                                      prototype_initializer='zeros')
         protos = p1.prototypes
@@ -75,8 +185,8 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_init_with_pdist(self):
-        p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
+    def test_prototypes1d_init_with_pdist(self):
+        p1 = prototypes.Prototypes1D(data=[self.x, self.y],
                                      prototype_distribution=[6, 9],
                                      prototype_initializer='zeros')
         protos = p1.prototypes
@@ -87,11 +197,11 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_func_initializer(self):
+    def test_prototypes1d_func_initializer(self):
         def my_initializer(*args, **kwargs):
             return torch.full((2, 99), 99), torch.tensor([0, 1])

-        p1 = prototypes.AddPrototypes1D(input_dim=99,
+        p1 = prototypes.Prototypes1D(input_dim=99,
                                      nclasses=2,
                                      prototypes_per_class=1,
                                      prototype_initializer=my_initializer)
@@ -103,8 +213,8 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

-    def test_addprototypes1d_forward(self):
-        p1 = prototypes.AddPrototypes1D(data=[self.x, self.y])
+    def test_prototypes1d_forward(self):
+        p1 = prototypes.Prototypes1D(data=[self.x, self.y])
         protos, _ = p1()
         actual = protos.detach().numpy()
         desired = torch.ones(2, 3)
@@ -113,6 +223,16 @@ class TestPrototypes(unittest.TestCase):
                                                         decimal=5)
         self.assertIsNone(mismatch)

+    def test_prototypes1d_dist_validate(self):
+        p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
+        with self.assertWarns(UserWarning):
+            _ = p1._validate_prototype_distribution()
+
+    def test_prototypes1d_validate_extra_repr_not_empty(self):
+        p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
+        rep = p1.extra_repr()
+        self.assertNotEqual(rep, '')
+
     def tearDown(self):
         del self.x, self.y, self.gen
         _ = torch.seed()