43 Commits

Author SHA1 Message Date
blackfly
0cfbc0473b Bump version: 0.1.1-dev0 → 0.1.1-rc0 2020-04-27 12:56:42 +02:00
blackfly
cf0659d881 Add test cases to test newly added features 2020-04-27 12:49:54 +02:00
blackfly
d17b9a3346 Modify stratified_min function 2020-04-27 12:48:12 +02:00
blackfly
532f63b1de Add one-hot support in functions/initializers.py 2020-04-27 12:47:44 +02:00
blackfly
c11a3860df Refactor functions/losses.py 2020-04-27 12:47:15 +02:00
blackfly
dab91e471a Add minor cosmetic changes 2020-04-27 12:45:42 +02:00
blackfly
a167565857 Update Prototypes1D 2020-04-27 12:44:19 +02:00
blackfly
e063625486 Remove some requirements from requirements.txt 2020-04-15 12:12:44 +02:00
blackfly
89eb5358a0 Try fixing tqdm AttributeError 2020-04-14 20:26:49 +02:00
blackfly
5c59515128 Update github action 'tests' 2020-04-14 20:19:23 +02:00
blackfly
7eb7a6b194 Update .travis.yml 2020-04-14 20:19:15 +02:00
blackfly
5811c4b9f9 Add requirements.txt 2020-04-14 20:18:45 +02:00
blackfly
7b1887d56e Add 'requests' requirements for downloading datasets 2020-04-14 20:04:10 +02:00
blackfly
63a25e7a38 Refactor examples/glvq_iris.py 2020-04-14 19:57:19 +02:00
blackfly
a0f20a40f6 Add test cases to test recently added features 2020-04-14 19:53:51 +02:00
blackfly
88cbe0a126 Add alias for squared_euclidean_distance 2020-04-14 19:53:26 +02:00
blackfly
a3548e0ddd Add stratified_min competition function 2020-04-14 19:52:59 +02:00
blackfly
3cfbc49254 Fix generator bug in stratified_random initializer 2020-04-14 19:51:54 +02:00
blackfly
2b82830590 Add 'datasets' to main package __init__.py 2020-04-14 19:51:14 +02:00
blackfly
553b1e1a65 Refactor datasets and use float32 instead of float64 in Tecator 2020-04-14 19:49:59 +02:00
blackfly
a9d2855323 Refactor prototypes module and begin documentation 2020-04-14 19:48:46 +02:00
blackfly
cf7d7b5d9d Add tests/test_datasets.py 2020-04-14 19:47:59 +02:00
blackfly
a22c752342 Add prototorch/datasets 2020-04-14 19:47:34 +02:00
blackfly
4158586cb9 More cosmetic changes 2020-04-11 18:12:37 +02:00
blackfly
f80d9648c3 Minor cosmetic changes 2020-04-11 17:35:32 +02:00
blackfly
e54bf07030 Populate init files 2020-04-11 17:35:00 +02:00
blackfly
8c629c0cb1 Fix a bunch of codacy code-style issues 2020-04-11 15:47:26 +02:00
blackfly
8f3a43f62a Remove assert statements following codacy security recommendation
"Use of assert detected. The enclosed code will be removed when compiling to
optimised byte code."
2020-04-11 15:45:29 +02:00
blackfly
955661af95 Remove utils import from prototorch/__init__.py 2020-04-11 15:12:53 +02:00
blackfly
c54d14c55e Remove datasets import from prototorch/__init__.py 2020-04-11 14:59:11 +02:00
blackfly
6090aad176 Update examples/glvq_iris.py to use the recently modified API 2020-04-11 14:29:06 +02:00
blackfly
1ec7bd261b Add small API changes and more test cases 2020-04-11 14:28:22 +02:00
blackfly
da3b0cc262 Update RELEASE.md 2020-04-11 14:26:05 +02:00
blackfly
f640a22cf2 Rename input to x in activation functions 2020-04-11 14:25:35 +02:00
blackfly
c843ace63d Update README.md 2020-04-11 14:22:34 +02:00
blackfly
242c9de3b6 Fix codecov reporting in .travis.yml 2020-04-08 23:37:11 +02:00
blackfly
438a5b9360 Bump version: 0.1.0-rc0 → 0.1.1-dev0 2020-04-08 23:00:34 +02:00
blackfly
f98f3d095e Update .travis.yml to cache artifacts from test scripts 2020-04-08 22:47:31 +02:00
blackfly
21b0279839 Add test cases 2020-04-08 22:47:08 +02:00
blackfly
b19cbcb76a Fix zero-distance bug in glvq_loss 2020-04-08 22:46:08 +02:00
blackfly
7d5ab81dbf Clean up prototorch/functions/distances.py 2020-04-08 22:44:02 +02:00
blackfly
bde408a80e Prepare activation and competition functions for TorchScript 2020-04-08 22:42:56 +02:00
blackfly
900955d67a Rename tests github action 2020-04-08 22:34:26 +02:00
24 changed files with 1056 additions and 299 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 0.1.0-rc0 current_version = 0.1.1-rc0
commit = True commit = True
tag = True tag = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))? parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?

View File

@@ -1,7 +1,7 @@
# This workflow will install Python dependencies, run tests and lint with a single version of Python # This workflow will install Python dependencies, run tests and lint with a single version of Python
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
name: Tests name: tests
on: on:
push: push:
@@ -24,6 +24,9 @@ jobs:
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install . pip install .
- name: Install extras
run: |
pip install -r requirements.txt
- name: Lint with flake8 - name: Lint with flake8
run: | run: |
pip install flake8 pip install flake8

View File

@@ -2,19 +2,17 @@ dist: bionic
sudo: false sudo: false
language: python language: python
python: 3.8 python: 3.8
# cache: cache:
# directories: directories:
# - $HOME/.prototorch - ./tests/artifacts
install: install:
- pip install . --progress-bar off - pip install . --progress-bar off
- pip install codecov - pip install -r requirements.txt
- pip install pytest
script: script:
- coverage run -m pytest - coverage run -m pytest
# Push the results to codecov # Push the results to codecov
after_success: after_success:
- codecov - bash <(curl -s https://codecov.io/bash)
# - bash <(curl -s https://codecov.io/bash)

View File

@@ -4,11 +4,12 @@ ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
prototype-based machine learning algorithms. prototype-based machine learning algorithms.
[![Build Status](https://travis-ci.org/si-cim/prototorch.svg?branch=master)](https://travis-ci.org/si-cim/prototorch) [![Build Status](https://travis-ci.org/si-cim/prototorch.svg?branch=master)](https://travis-ci.org/si-cim/prototorch)
[![GitHub version](https://badge.fury.io/gh/si-cim%2Fprototorch.svg)](https://badge.fury.io/gh/si-cim%2Fprototorch) ![tests](https://github.com/si-cim/prototorch/workflows/tests/badge.svg)
[![PyPI version](https://badge.fury.io/py/prototorch.svg)](https://badge.fury.io/py/prototorch) [![GitHub tag (latest by date)](https://img.shields.io/github/v/tag/si-cim/prototorch?color=yellow&label=version)](https://github.com/si-cim/prototorch/releases)
![Tests](https://github.com/si-cim/prototorch/workflows/Tests/badge.svg) [![PyPI](https://img.shields.io/pypi/v/prototorch)](https://pypi.org/project/prototorch/)
[![codecov](https://codecov.io/gh/si-cim/prototorch/branch/master/graph/badge.svg)](https://codecov.io/gh/si-cim/prototorch) [![codecov](https://codecov.io/gh/si-cim/prototorch/branch/master/graph/badge.svg)](https://codecov.io/gh/si-cim/prototorch)
[![Downloads](https://pepy.tech/badge/prototorch)](https://pepy.tech/project/prototorch) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/76273904bf9343f0a8b29cd8aca242e7)](https://www.codacy.com/gh/si-cim/prototorch?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=si-cim/prototorch&amp;utm_campaign=Badge_Grade)
![PyPI - Downloads](https://img.shields.io/pypi/dm/prototorch?color=blue)
[![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE) [![GitHub license](https://img.shields.io/github/license/si-cim/prototorch)](https://github.com/si-cim/prototorch/blob/master/LICENSE)
## Description ## Description
@@ -27,12 +28,12 @@ provided by PyTorch.
## Installation ## Installation
ProtoTorch can be installed using `pip`. ProtoTorch can be installed using `pip`.
``` ```bash
pip install prototorch pip install prototorch
``` ```
To install the bleeding-edge features and improvements: To install the bleeding-edge features and improvements:
``` ```bash
git clone https://github.com/si-cim/prototorch.git git clone https://github.com/si-cim/prototorch.git
git checkout dev git checkout dev
cd prototorch cd prototorch

View File

@@ -1,3 +1,11 @@
# Release 0.1.0-dev0 # ProtoTorch Releases
## Release 0.1.1-dev0
### Includes
- Minor bugfixes.
- 100% line coverage.
## Release 0.1.0-dev0
Initial public release of ProtoTorch. Initial public release of ProtoTorch.

View File

@@ -1,4 +1,4 @@
"""ProtoTorch GLVQ example using 2D Iris data""" """ProtoTorch GLVQ example using 2D Iris data."""
import numpy as np import numpy as np
import torch import torch
@@ -8,7 +8,7 @@ from sklearn.preprocessing import StandardScaler
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.prototypes import AddPrototypes1D from prototorch.modules.prototypes import Prototypes1D
# Prepare and preprocess the data # Prepare and preprocess the data
scaler = StandardScaler() scaler = StandardScaler()
@@ -21,11 +21,12 @@ x_train = scaler.transform(x_train)
# Define the GLVQ model # Define the GLVQ model
class Model(torch.nn.Module): class Model(torch.nn.Module):
def __init__(self, **kwargs): def __init__(self, **kwargs):
"""GLVQ model."""
super().__init__() super().__init__()
self.p1 = AddPrototypes1D(input_dim=2, self.p1 = Prototypes1D(input_dim=2,
prototypes_per_class=1, prototypes_per_class=1,
nclasses=3, nclasses=3,
prototype_initializer='zeros') prototype_initializer='zeros')
def forward(self, x): def forward(self, x):
protos = self.p1.prototypes protos = self.p1.prototypes
@@ -41,13 +42,17 @@ model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = GLVQLoss(squashing='sigmoid_beta', beta=10) criterion = GLVQLoss(squashing='sigmoid_beta', beta=10)
x_in = torch.Tensor(x_train)
y_in = torch.Tensor(y_train)
# Training loop # Training loop
fig = plt.figure('Prototype Visualization') title = 'Prototype Visualization'
fig = plt.figure(title)
for epoch in range(70): for epoch in range(70):
# Compute loss. # Compute loss
distances, plabels = model(torch.tensor(x_train)) dis, plabels = model(x_in)
loss = criterion([distances, plabels], torch.tensor(y_train)) loss = criterion([dis, plabels], y_in)
print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():02.02f}') print(f'Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f}')
# Take a gradient descent step # Take a gradient descent step
optimizer.zero_grad() optimizer.zero_grad()
@@ -60,6 +65,9 @@ for epoch in range(70):
# Visualize the data and the prototypes # Visualize the data and the prototypes
ax = fig.gca() ax = fig.gca()
ax.cla() ax.cla()
ax.set_title(title)
ax.set_xlabel('Data dimension 1')
ax.set_ylabel('Data dimension 2')
cmap = 'viridis' cmap = 'viridis'
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k') ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor='k')
ax.scatter(protos[:, 0], ax.scatter(protos[:, 0],
@@ -71,28 +79,17 @@ for epoch in range(70):
s=50) s=50)
# Paint decision regions # Paint decision regions
border = 1
resolution = 50
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
x_min, x_max = x[:, 0].min(), x[:, 0].max() x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
y_min, y_max = x[:, 1].min(), x[:, 1].max() y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
x_min, x_max = x_min - border, x_max + border xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
y_min, y_max = y_min - border, y_max + border np.arange(y_min, y_max, 1 / 50))
try:
xx, yy = np.meshgrid(np.arange(x_min, x_max, 1.0 / resolution),
np.arange(y_min, y_max, 1.0 / resolution))
except ValueError as ve:
print(ve)
raise ValueError(f'x_min: {x_min}, x_max: {x_max}. '
f'x_min - x_max is {x_max - x_min}.')
except MemoryError as me:
print(me)
raise ValueError('Too many points. ' 'Try reducing the resolution.')
mesh_input = np.c_[xx.ravel(), yy.ravel()] mesh_input = np.c_[xx.ravel(), yy.ravel()]
torch_input = torch.from_numpy(mesh_input) torch_input = torch.Tensor(mesh_input)
d = model(torch_input)[0] d = model(torch_input)[0]
y_pred = np.argmin(d.detach().numpy(), axis=1) y_pred = np.argmin(d.detach().numpy(),
axis=1) # assume one prototype per class
y_pred = y_pred.reshape(xx.shape) y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions # Plot voronoi regions
@@ -100,4 +97,5 @@ for epoch in range(70):
ax.set_xlim(left=x_min + 0, right=x_max - 0) ax.set_xlim(left=x_min + 0, right=x_max - 0)
ax.set_ylim(bottom=y_min + 0, top=y_max - 0) ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
plt.pause(0.1) plt.pause(0.1)

View File

@@ -1 +1,11 @@
__version__ = '0.1.0-rc0' """ProtoTorch package."""
__version__ = '0.1.1-rc0'
from prototorch import datasets, functions, modules
__all__ = [
'datasets',
'functions',
'modules',
]

View File

@@ -0,0 +1,7 @@
"""ProtoTorch datasets."""
from .tecator import Tecator
__all__ = [
'Tecator',
]

View File

@@ -0,0 +1,87 @@
"""ProtoTorch abstract dataset classes.
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
For the original code, see:
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
"""
import os
import torch
class Dataset(torch.utils.data.Dataset):
"""Abstract dataset class to be inherited."""
_repr_indent = 2
def __init__(self, root):
if isinstance(root, torch._six.string_classes):
root = os.path.expanduser(root)
self.root = root
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class ProtoDataset(Dataset):
"""Abstract dataset class to be inherited."""
training_file = 'training.pt'
test_file = 'test.pt'
def __init__(self, root, train=True, download=True, verbose=True):
super().__init__(root)
self.train = train # training set or test set
self.verbose = verbose
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found. '
'You can use download=True to download it')
data_file = self.training_file if self.train else self.test_file
self.data, self.targets = torch.load(
os.path.join(self.processed_folder, data_file))
@property
def raw_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'raw')
@property
def processed_folder(self):
return os.path.join(self.root, self.__class__.__name__, 'processed')
@property
def class_to_idx(self):
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self):
return (os.path.exists(
os.path.join(self.processed_folder, self.training_file))
and os.path.exists(
os.path.join(self.processed_folder, self.test_file)))
def __repr__(self):
head = 'Dataset ' + self.__class__.__name__
body = ['Number of datapoints: {}'.format(self.__len__())]
if self.root is not None:
body.append('Root location: {}'.format(self.root))
body += self.extra_repr().splitlines()
lines = [head] + [' ' * self._repr_indent + line for line in body]
return '\n'.join(lines)
def extra_repr(self):
return f"Split: {'Train' if self.train is True else 'Test'}"
def __len__(self):
return len(self.data)
def download(self):
raise NotImplementedError

View File

@@ -0,0 +1,102 @@
"""Tecator dataset for classification.
URL:
http://lib.stat.cmu.edu/datasets/tecator
LICENCE / TERMS / COPYRIGHT:
This is the Tecator data set: The task is to predict the fat content
of a meat sample on the basis of its near infrared absorbance spectrum.
-------------------------------------------------------------------------
1. Statement of permission from Tecator (the original data source)
These data are recorded on a Tecator Infratec Food and Feed Analyzer
working in the wavelength range 850 - 1050 nm by the Near Infrared
Transmission (NIT) principle. Each sample contains finely chopped pure
meat with different moisture, fat and protein contents.
If results from these data are used in a publication we want you to
mention the instrument and company name (Tecator) in the publication.
In addition, please send a preprint of your article to
Karin Thente, Tecator AB,
Box 70, S-263 21 Hoganas, Sweden
The data are available in the public domain with no responsability from
the original data source. The data can be redistributed as long as this
permission note is attached.
For more information about the instrument - call Perstorp Analytical's
representative in your area.
Description:
For each meat sample the data consists of a 100 channel spectrum of
absorbances and the contents of moisture (water), fat and protein.
The absorbance is -log10 of the transmittance
measured by the spectrometer. The three contents, measured in percent,
are determined by analytic chemistry.
"""
import os
import numpy as np
import torch
from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset):
"""Tecator dataset for classification."""
resources = [
('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
'ba5607c580d0f91bb27dc29d13c2f8df'),
] # (google_storage_id, md5hash)
classes = ['0 - low_fat', '1 - high_fat']
def __getitem__(self, index):
img, target = self.data[index], int(self.targets[index])
return img, target
def download(self):
"""Download the data if it doesn't exist in already."""
if self._check_exists():
return
if self.verbose:
print('Making directories...')
os.makedirs(self.raw_folder, exist_ok=True)
os.makedirs(self.processed_folder, exist_ok=True)
if self.verbose:
print('Downloading...')
for fileid, md5 in self.resources:
filename = 'tecator.npz'
download_file_from_google_drive(fileid,
root=self.raw_folder,
filename=filename,
md5=md5)
if self.verbose:
print('Processing...')
with np.load(os.path.join(self.raw_folder, 'tecator.npz'),
allow_pickle=False) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
training_set = [
torch.tensor(x_train, dtype=torch.float32),
torch.tensor(y_train),
]
test_set = [
torch.tensor(x_test, dtype=torch.float32),
torch.tensor(y_test),
]
with open(os.path.join(self.processed_folder, self.training_file),
'wb') as f:
torch.save(training_set, f)
with open(os.path.join(self.processed_folder, self.test_file),
'wb') as f:
torch.save(test_set, f)
if self.verbose:
print('Done!')

View File

@@ -0,0 +1,12 @@
"""ProtoTorch functions."""
from .activations import identity, sigmoid_beta, swish_beta
from .competitions import knnc, wtac
__all__ = [
'identity',
'sigmoid_beta',
'swish_beta',
'knnc',
'wtac',
]

View File

@@ -5,44 +5,60 @@ import torch
ACTIVATIONS = dict() ACTIVATIONS = dict()
def register_activation(func): # def register_activation(scriptf):
ACTIVATIONS[func.__name__] = func # ACTIVATIONS[scriptf.name] = scriptf
return func # return scriptf
def register_activation(function):
"""Add the activation function to the registry."""
ACTIVATIONS[function.__name__] = function
return function
@register_activation @register_activation
def identity(input, **kwargs): # @torch.jit.script
""":math:`f(x) = x`""" def identity(x, beta=torch.tensor(0)):
return input """Identity activation function.
Definition:
:math:`f(x) = x`
"""
return x
@register_activation @register_activation
def sigmoid_beta(input, beta=10): # @torch.jit.script
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}` def sigmoid_beta(x, beta=torch.tensor(10)):
r"""Sigmoid activation function with scaling.
Definition:
:math:`f(x) = \frac{1}{1 + e^{-\beta x}}`
Keyword Arguments: Keyword Arguments:
beta (float): Parameter :math:`\\beta` beta (`torch.tensor`): Scaling parameter :math:`\beta`
""" """
out = torch.reciprocal(1.0 + torch.exp(-beta * input)) out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * x))
return out return out
@register_activation @register_activation
def swish_beta(input, beta=10): # @torch.jit.script
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}` def swish_beta(x, beta=torch.tensor(10)):
r"""Swish activation function with scaling.
Definition:
:math:`f(x) = \frac{x}{1 + e^{-\beta x}}`
Keyword Arguments: Keyword Arguments:
beta (float): Parameter :math:`\\beta` beta (`torch.tensor`): Scaling parameter :math:`\beta`
""" """
out = input * sigmoid_beta(input, beta=beta) out = x * sigmoid_beta(x, beta=beta)
return out return out
def get_activation(funcname): def get_activation(funcname):
"""Deserialize the activation function."""
if callable(funcname): if callable(funcname):
return funcname return funcname
else: if funcname in ACTIVATIONS:
if funcname in ACTIVATIONS: return ACTIVATIONS.get(funcname)
return ACTIVATIONS.get(funcname) raise NameError(f'Activation {funcname} was not found.')
else:
raise NameError(f'Activation {funcname} was not found.')

View File

@@ -3,13 +3,43 @@
import torch import torch
# @torch.jit.script
def stratified_min(distances, labels):
clabels = torch.unique(labels, dim=0)
nclasses = clabels.size()[0]
if distances.size()[1] == nclasses:
# skip if only one prototype per class
return distances
batch_size = distances.size()[0]
winning_distances = torch.zeros(nclasses, batch_size)
inf = torch.full_like(distances.T, fill_value=float('inf'))
# distances_to_wpluses = torch.where(matcher, distances, inf)
for i, cl in enumerate(clabels):
# cdists = distances.T[labels == cl]
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
if labels.ndim == 2:
# if the labels are one-hot vectors
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
cdists = torch.where(matcher, distances.T, inf).T
winning_distances[i] = torch.min(cdists, dim=1,
keepdim=True).values.squeeze()
if labels.ndim == 2:
# Transpose to return with `batch_size` first and
# reverse the columns to fix the ordering of the classes
return torch.flip(winning_distances.T, dims=(1, ))
return winning_distances.T # return with `batch_size` first
# @torch.jit.script
def wtac(distances, labels): def wtac(distances, labels):
winning_indices = torch.min(distances, dim=1).indices winning_indices = torch.min(distances, dim=1).indices
winning_labels = labels[winning_indices].squeeze() winning_labels = labels[winning_indices].squeeze()
return winning_labels return winning_labels
# @torch.jit.script
def knnc(distances, labels, k): def knnc(distances, labels, k):
winning_indices = torch.topk(-distances, k=k, dim=1).indices winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
winning_labels = labels[winning_indices].squeeze() winning_labels = labels[winning_indices].squeeze()
return winning_labels return winning_labels

View File

@@ -28,26 +28,19 @@ def euclidean_distance(x, y):
def lpnorm_distance(x, y, p): def lpnorm_distance(x, y, p):
"""Compute :math:`{\\langle x, y \\rangle}_p`. r"""Compute :math:`{\langle x, y \rangle}_p`.
Expected dimension of x is 2. Expected dimension of x is 2.
Expected dimension of y is 2. Expected dimension of y is 2.
""" """
# # DEPRECATED in favor of torch.cdist
# expanded_x = x.unsqueeze(dim=1)
# batchwise_difference = y - expanded_x
# differences_raised = torch.pow(batchwise_difference, p)
# distances_raised = torch.sum(differences_raised, axis=2)
# distances = torch.pow(distances_raised, 1.0 / p)
# return distances
distances = torch.cdist(x, y, p=p) distances = torch.cdist(x, y, p=p)
return distances return distances
def omega_distance(x, y, omega): def omega_distance(x, y, omega):
"""Omega distance. r"""Omega distance.
Compute :math:`{\\langle \\Omega x, \\Omega y \\rangle}_p` Compute :math:`{\langle \Omega x, \Omega y \rangle}_p`
Expected dimension of x is 2. Expected dimension of x is 2.
Expected dimension of y is 2. Expected dimension of y is 2.
@@ -60,9 +53,9 @@ def omega_distance(x, y, omega):
def lomega_distance(x, y, omegas): def lomega_distance(x, y, omegas):
"""Localized Omega distance. r"""Localized Omega distance.
Compute :math:`{\\langle \\Omega_k x, \\Omega_k y_k \\rangle}_p` Compute :math:`{\langle \Omega_k x, \Omega_k y_k \rangle}_p`
Expected dimension of x is 2. Expected dimension of x is 2.
Expected dimension of y is 2. Expected dimension of y is 2.
@@ -76,3 +69,7 @@ def lomega_distance(x, y, omegas):
distances = torch.sum(differences_squared, dim=2) distances = torch.sum(differences_squared, dim=2)
distances = distances.permute(1, 0) distances = distances.permute(1, 0)
return distances return distances
# Aliases
sed = squared_euclidean_distance

View File

@@ -7,87 +7,97 @@ import torch
INITIALIZERS = dict() INITIALIZERS = dict()
def register_initializer(func): def register_initializer(function):
INITIALIZERS[func.__name__] = func """Add the initializer to the registry."""
return func INITIALIZERS[function.__name__] = function
return function
def labels_from(distribution): def labels_from(distribution, one_hot=True):
"""Takes a distribution tensor and returns a labels tensor.""" """Takes a distribution tensor and returns a labels tensor."""
nclasses = distribution.shape[0] nclasses = distribution.shape[0]
llist = [[i] * n for i, n in zip(range(nclasses), distribution)] llist = [[i] * n for i, n in zip(range(nclasses), distribution)]
# labels = [l for cl in llist for l in cl] # flatten the list of lists # labels = [l for cl in llist for l in cl] # flatten the list of lists
labels = list(chain(*llist)) # flatten using itertools.chain flat_llist = list(chain(*llist)) # flatten label list with itertools.chain
return torch.tensor(labels, requires_grad=False) plabels = torch.tensor(flat_llist, requires_grad=False)
if one_hot:
return torch.eye(nclasses)[plabels]
return plabels
@register_initializer @register_initializer
def ones(x_train, y_train, prototype_distribution): def ones(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.ones(nprotos, *x_train.shape[1:]) protos = torch.ones(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def zeros(x_train, y_train, prototype_distribution): def zeros(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.zeros(nprotos, *x_train.shape[1:]) protos = torch.zeros(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def rand(x_train, y_train, prototype_distribution): def rand(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.rand(nprotos, *x_train.shape[1:]) protos = torch.rand(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def randn(x_train, y_train, prototype_distribution): def randn(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
protos = torch.randn(nprotos, *x_train.shape[1:]) protos = torch.randn(nprotos, *x_train.shape[1:])
plabels = labels_from(prototype_distribution) plabels = labels_from(prototype_distribution, one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def stratified_mean(x_train, y_train, prototype_distribution): def stratified_mean(x_train, y_train, prototype_distribution, one_hot=True):
nprotos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1] pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim) protos = torch.empty(nprotos, pdim)
plabels = labels_from(prototype_distribution) plabels = labels_from(prototype_distribution, one_hot)
for i, l in enumerate(plabels): for i, label in enumerate(plabels):
xl = x_train[y_train == l] matcher = torch.eq(label.unsqueeze(dim=0), y_train)
if one_hot:
nclasses = y_train.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
xl = x_train[matcher]
mean_xl = torch.mean(xl, dim=0) mean_xl = torch.mean(xl, dim=0)
protos[i] = mean_xl protos[i] = mean_xl
plabels = labels_from(prototype_distribution, one_hot=one_hot)
return protos, plabels return protos, plabels
@register_initializer @register_initializer
def stratified_random(x_train, y_train, prototype_distribution): def stratified_random(x_train, y_train, prototype_distribution, one_hot=True):
gen = torch.manual_seed(torch.initial_seed())
nprotos = torch.sum(prototype_distribution) nprotos = torch.sum(prototype_distribution)
pdim = x_train.shape[1] pdim = x_train.shape[1]
protos = torch.empty(nprotos, pdim) protos = torch.empty(nprotos, pdim)
plabels = labels_from(prototype_distribution) plabels = labels_from(prototype_distribution, one_hot)
for i, l in enumerate(plabels): for i, label in enumerate(plabels):
xl = x_train[y_train == l] matcher = torch.eq(label.unsqueeze(dim=0), y_train)
rand_index = torch.zeros(1).long().random_(0, if one_hot:
xl.shape[1] - 1, nclasses = y_train.size()[1]
generator=gen) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
xl = x_train[matcher]
rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1)
random_xl = xl[rand_index] random_xl = xl[rand_index]
protos[i] = random_xl protos[i] = random_xl
plabels = labels_from(prototype_distribution, one_hot=one_hot)
return protos, plabels return protos, plabels
def get_initializer(funcname): def get_initializer(funcname):
"""Deserialize the initializer."""
if callable(funcname): if callable(funcname):
return funcname return funcname
else: if funcname in INITIALIZERS:
if funcname in INITIALIZERS: return INITIALIZERS.get(funcname)
return INITIALIZERS.get(funcname) raise NameError(f'Initializer {funcname} was not found.')
else:
raise NameError(f'Initializer {funcname} was not found.')

View File

@@ -3,23 +3,24 @@
import torch import torch
def glvq_loss(distances, target_labels, prototype_labels): def _get_dp_dm(distances, targets, plabels):
"""GLVQ loss function with support for one-hot labels.""" matcher = torch.eq(targets.unsqueeze(dim=1), plabels)
matcher = torch.eq(target_labels.unsqueeze(dim=1), prototype_labels) if plabels.ndim == 2:
if prototype_labels.ndim == 2:
# if the labels are one-hot vectors # if the labels are one-hot vectors
nclasses = target_labels.size()[1] nclasses = targets.size()[1]
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
not_matcher = torch.bitwise_not(matcher) not_matcher = torch.bitwise_not(matcher)
dplus_criterion = distances * matcher > 0.0
dminus_criterion = distances * not_matcher > 0.0
inf = torch.full_like(distances, fill_value=float('inf')) inf = torch.full_like(distances, fill_value=float('inf'))
distances_to_wpluses = torch.where(dplus_criterion, distances, inf) d_matching = torch.where(matcher, distances, inf)
distances_to_wminuses = torch.where(dminus_criterion, distances, inf) d_unmatching = torch.where(not_matcher, distances, inf)
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values dp = torch.min(d_matching, dim=1, keepdim=True).values
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values dm = torch.min(d_unmatching, dim=1, keepdim=True).values
return dp, dm
mu = (dpluses - dminuses) / (dpluses + dminuses)
def glvq_loss(distances, target_labels, prototype_labels):
"""GLVQ loss function with support for one-hot labels."""
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = (dp - dm) / (dp + dm)
return mu return mu

View File

@@ -0,0 +1,7 @@
"""ProtoTorch modules."""
from .prototypes import Prototypes1D
__all__ = [
'Prototypes1D',
]

View File

@@ -7,12 +7,11 @@ from prototorch.functions.losses import glvq_loss
class GLVQLoss(torch.nn.Module): class GLVQLoss(torch.nn.Module):
"""GLVQ Loss."""
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs): def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.margin = margin self.margin = margin
self.squashing = get_activation(squashing) self.squashing = get_activation(squashing)
self.beta = beta self.beta = torch.tensor(beta)
def forward(self, outputs, targets): def forward(self, outputs, targets):
distances, plabels = outputs distances, plabels = outputs

View File

@@ -1,57 +1,165 @@
"""ProtoTorch prototype modules.""" """ProtoTorch prototype modules."""
import warnings
import numpy as np
import torch import torch
from prototorch.functions.competitions import wtac
from prototorch.functions.distances import sed
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
class AddPrototypes1D(torch.nn.Module): class _Prototypes(torch.nn.Module):
"""Abstract prototypes class."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def _validate_prototype_distribution(self):
if 0 in self.prototype_distribution:
warnings.warn('Are you sure about the `0` in '
'`prototype_distribution`?')
def extra_repr(self):
return f'prototypes.shape: {tuple(self.prototypes.shape)}'
def forward(self):
return self.prototypes, self.prototype_labels
class Prototypes1D(_Prototypes):
r"""Create a learnable set of one-dimensional prototypes.
TODO Complete this doc-string
Kwargs:
prototypes_per_class: number of prototypes to use per class.
Default: ``1``
prototype_initializer: prototype initializer.
Default: ``'ones'``
prototype_distribution: prototype distribution vector.
Default: ``None``
input_dim: dimension of the incoming data.
nclasses: number of classes.
data: If set to ``None``, data-dependent initializers will be ignored.
Default: ``None``
Shape:
- Input: :math:`(N, H_{in})`
where :math:`H_{in} = \text{input_dim}`.
- Output: :math:`(N, H_{out})`
where :math:`H_{out} = \text{total_prototypes}`.
Attributes:
prototypes: the learnable weights of the module of shape
:math:`(\text{total_prototypes}, \text{prototype_dimension})`.
prototype_labels: the non-learnable labels of the prototypes.
Examples::
>>> p = Prototypes1D(input_dim=20, nclasses=10)
>>> input = torch.randn(128, 20)
>>> output = m(input)
>>> print(output.size())
torch.Size([20, 10])
"""
def __init__(self, def __init__(self,
prototypes_per_class=1, prototypes_per_class=1,
prototype_distribution=None,
prototype_initializer='ones', prototype_initializer='ones',
prototype_distribution=None,
data=None, data=None,
dtype=torch.float32,
one_hot_labels=False,
**kwargs): **kwargs):
# Convert tensors to python lists before processing
if prototype_distribution is not None:
if not isinstance(prototype_distribution, list):
prototype_distribution = prototype_distribution.tolist()
if data is None: if data is None:
if 'input_dim' not in kwargs: if 'input_dim' not in kwargs:
raise NameError('`input_dim` required if ' raise NameError('`input_dim` required if '
'no `data` is provided.') 'no `data` is provided.')
if prototype_distribution is not None: if prototype_distribution:
nclasses = sum(prototype_distribution) kwargs_nclasses = sum(prototype_distribution)
else: else:
if 'nclasses' not in kwargs: if 'nclasses' not in kwargs:
raise NameError('`prototype_distribution` required if ' raise NameError('`prototype_distribution` required if '
'both `data` and `nclasses` are not ' 'both `data` and `nclasses` are not '
'provided.') 'provided.')
nclasses = kwargs.pop('nclasses') kwargs_nclasses = kwargs.pop('nclasses')
input_dim = kwargs.pop('input_dim') input_dim = kwargs.pop('input_dim')
# input_shape = (input_dim, ) if prototype_initializer in [
x_train = torch.rand(nclasses, input_dim) 'stratified_mean', 'stratified_random'
y_train = torch.arange(nclasses) ]:
warnings.warn(
f'`prototype_initializer`: `{prototype_initializer}` '
'requires `data`, but `data` is not provided. '
'Using randomly generated data instead.')
x_train = torch.rand(kwargs_nclasses, input_dim)
y_train = torch.arange(kwargs_nclasses)
if one_hot_labels:
y_train = torch.eye(kwargs_nclasses)[y_train]
data = [x_train, y_train]
else: x_train, y_train = data
x_train, y_train = data x_train = torch.as_tensor(x_train).type(dtype)
x_train = torch.as_tensor(x_train) y_train = torch.as_tensor(y_train).type(torch.int)
y_train = torch.as_tensor(y_train) nclasses = torch.unique(y_train, dim=-1).shape[-1]
if nclasses == 1:
warnings.warn('Are you sure about having one class only?')
if x_train.ndim != 2:
raise ValueError('`data[0].ndim != 2`.')
if y_train.ndim == 2:
if y_train.shape[1] == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` '
'but target labels are not one-hot-encoded.')
if y_train.shape[1] != 1 and not one_hot_labels:
raise ValueError('`one_hot_labels` is set to `False` '
'but target labels in `data` '
'are one-hot-encoded.')
if y_train.ndim == 1 and one_hot_labels:
raise ValueError('`one_hot_labels` is set to `True` '
'but target labels are not one-hot-encoded.')
# Verify input dimension if `input_dim` is provided
if 'input_dim' in kwargs:
input_dim = kwargs.pop('input_dim')
if input_dim != x_train.shape[1]:
raise ValueError(f'Provided `input_dim`={input_dim} does '
'not match data dimension '
f'`data[0].shape[1]`={x_train.shape[1]}')
# Verify the number of classes if `nclasses` is provided
if 'nclasses' in kwargs:
kwargs_nclasses = kwargs.pop('nclasses')
if kwargs_nclasses != nclasses:
raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
'not match data labels '
'`torch.unique(data[1]).shape[0]`'
f'={nclasses}')
super().__init__(**kwargs) super().__init__(**kwargs)
self.prototypes_per_class = prototypes_per_class
if not prototype_distribution:
prototype_distribution = [prototypes_per_class] * nclasses
with torch.no_grad(): with torch.no_grad():
if not prototype_distribution: self.prototype_distribution = torch.tensor(prototype_distribution)
num_classes = torch.unique(y_train).shape[0]
self.prototype_distribution = torch.tensor( self._validate_prototype_distribution()
[self.prototypes_per_class] * num_classes)
else:
self.prototype_distribution = torch.tensor(
prototype_distribution)
self.prototype_initializer = get_initializer(prototype_initializer) self.prototype_initializer = get_initializer(prototype_initializer)
prototypes, prototype_labels = self.prototype_initializer( prototypes, prototype_labels = self.prototype_initializer(
x_train, x_train,
y_train, y_train,
prototype_distribution=self.prototype_distribution) prototype_distribution=self.prototype_distribution,
one_hot=one_hot_labels,
)
# Register module parameters
self.prototypes = torch.nn.Parameter(prototypes) self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = prototype_labels self.prototype_labels = prototype_labels
def forward(self):
return self.prototypes, self.prototype_labels

5
requirements.txt Normal file
View File

@@ -0,0 +1,5 @@
matplotlib==3.1.2
pytest==5.3.4
requests==2.22.0
codecov==2.0.22
tqdm==4.44.1

View File

@@ -10,7 +10,7 @@ with open('README.md', 'r') as fh:
long_description = fh.read() long_description = fh.read()
setup(name='prototorch', setup(name='prototorch',
version='0.1.0-rc0', version='0.1.1-rc0',
description='Highly extensible, GPU-supported ' description='Highly extensible, GPU-supported '
'Learning Vector Quantization (LVQ) toolbox ' 'Learning Vector Quantization (LVQ) toolbox '
'built using PyTorch and its nn API.', 'built using PyTorch and its nn API.',
@@ -27,6 +27,7 @@ setup(name='prototorch',
'numpy>=1.9.1', 'numpy>=1.9.1',
], ],
extras_require={ extras_require={
'datasets': ['requests'],
'examples': [ 'examples': [
'sklearn', 'sklearn',
'matplotlib', 'matplotlib',

95
tests/test_datasets.py Normal file
View File

@@ -0,0 +1,95 @@
"""ProtoTorch datasets test suite."""
import os
import shutil
import unittest
import torch
from prototorch.datasets import abstract, tecator
class TestAbstract(unittest.TestCase):
def test_getitem(self):
with self.assertRaises(NotImplementedError):
abstract.Dataset('./artifacts')[0]
def test_len(self):
with self.assertRaises(NotImplementedError):
len(abstract.Dataset('./artifacts'))
class TestProtoDataset(unittest.TestCase):
def test_getitem(self):
with self.assertRaises(NotImplementedError):
abstract.ProtoDataset('./artifacts')[0]
def test_download(self):
with self.assertRaises(NotImplementedError):
abstract.ProtoDataset('./artifacts').download()
class TestTecator(unittest.TestCase):
def setUp(self):
self.artifacts_dir = './artifacts/Tecator'
self._remove_artifacts()
def _remove_artifacts(self):
if os.path.exists(self.artifacts_dir):
shutil.rmtree(self.artifacts_dir)
def test_download_false(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
self._remove_artifacts()
with self.assertRaises(RuntimeError):
_ = tecator.Tecator(rootdir, download=False)
def test_download_caching(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
_ = tecator.Tecator(rootdir, download=True, verbose=False)
_ = tecator.Tecator(rootdir, download=False, verbose=False)
def test_repr(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
train = tecator.Tecator(rootdir, download=True, verbose=True)
self.assertTrue('Split: Train' in train.__repr__())
def test_download_train(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
train = tecator.Tecator(root=rootdir,
train=True,
download=True,
verbose=False)
train = tecator.Tecator(root=rootdir, download=True, verbose=False)
x_train, y_train = train.data, train.targets
self.assertEqual(x_train.shape[0], 144)
self.assertEqual(y_train.shape[0], 144)
self.assertEqual(x_train.shape[1], 100)
def test_download_test(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
x_test, y_test = test.data, test.targets
self.assertEqual(x_test.shape[0], 71)
self.assertEqual(y_test.shape[0], 71)
self.assertEqual(x_test.shape[1], 100)
def test_class_to_idx(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
_ = test.class_to_idx
def test_getitem(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
x, y = test[0]
self.assertEqual(x.shape[0], 100)
self.assertIsInstance(y, int)
def test_loadable_with_dataloader(self):
rootdir = self.artifacts_dir.rpartition('/')[0]
test = tecator.Tecator(root=rootdir, train=False, verbose=False)
_ = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True)
def tearDown(self):
pass

View File

@@ -6,7 +6,148 @@ import numpy as np
import torch import torch
from prototorch.functions import (activations, competitions, distances, from prototorch.functions import (activations, competitions, distances,
initializers) initializers, losses)
class TestActivations(unittest.TestCase):
def setUp(self):
self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
self.x = torch.randn(1024, 1)
def test_registry(self):
self.assertIsNotNone(activations.ACTIVATIONS)
def test_funcname_deserialization(self):
for funcname in self.flist:
f = activations.get_activation(funcname)
iscallable = callable(f)
self.assertTrue(iscallable)
# def test_torch_script(self):
# for funcname in self.flist:
# f = activations.get_activation(funcname)
# self.assertIsInstance(f, torch.jit.ScriptFunction)
def test_callable_deserialization(self):
def dummy(x, **kwargs):
return x
for f in [dummy, lambda x: x]:
f = activations.get_activation(f)
iscallable = callable(f)
self.assertTrue(iscallable)
self.assertEqual(1, f(1))
def test_unknown_deserialization(self):
for funcname in ['blubb', 'foobar']:
with self.assertRaises(NameError):
_ = activations.get_activation(funcname)
def test_identity(self):
actual = activations.identity(self.x)
desired = self.x
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_sigmoid_beta1(self):
actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
desired = torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_swish_beta1(self):
actual = activations.swish_beta(self.x, beta=torch.tensor(1))
desired = self.x * torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
del self.x
class TestCompetitions(unittest.TestCase):
def setUp(self):
pass
def test_wtac(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.wtac(d, labels)
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_wtac_unequal_dist(self):
d = torch.tensor([[2., 3., 4.], [2., 3., 1.]])
labels = torch.tensor([0, 1, 1])
actual = competitions.wtac(d, labels)
desired = torch.tensor([0, 1])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_wtac_one_hot(self):
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
labels = torch.tensor([[0, 1], [1, 0]])
actual = competitions.wtac(d, labels)
desired = torch.tensor([[0, 1], [1, 0]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_min(self):
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
labels = torch.tensor([0, 0, 1, 2])
actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_min_one_hot(self):
d = torch.tensor([[1., 0., 2., 3.], [9., 8., 0, 1]])
labels = torch.tensor([0, 0, 1, 2])
labels = torch.eye(3)[labels]
actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_min_simple(self):
d = torch.tensor([[0., 2., 3.], [8., 0, 1]])
labels = torch.tensor([0, 1, 2])
actual = competitions.stratified_min(d, labels)
desired = torch.tensor([[0., 2., 3.], [8., 0., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_knnc_k1(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
pass
class TestDistances(unittest.TestCase): class TestDistances(unittest.TestCase):
@@ -167,103 +308,12 @@ class TestDistances(unittest.TestCase):
del self.x, self.y del self.x, self.y
class TestActivations(unittest.TestCase):
def setUp(self):
self.x = torch.randn(1024, 1)
def test_registry(self):
self.assertIsNotNone(activations.ACTIVATIONS)
def test_funcname_deserialization(self):
flist = ['identity', 'sigmoid_beta', 'swish_beta']
for funcname in flist:
f = activations.get_activation(funcname)
iscallable = callable(f)
self.assertTrue(iscallable)
def test_callable_deserialization(self):
def dummy(x, **kwargs):
return x
for f in [dummy, lambda x: x]:
f = activations.get_activation(f)
iscallable = callable(f)
self.assertTrue(iscallable)
self.assertEqual(1, f(1))
def test_unknown_deserialization(self):
for funcname in ['blubb', 'foobar']:
with self.assertRaises(NameError):
_ = activations.get_activation(funcname)
def test_identity(self):
actual = activations.identity(self.x)
desired = self.x
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_sigmoid_beta1(self):
actual = activations.sigmoid_beta(self.x, beta=1)
desired = torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_swish_beta1(self):
actual = activations.swish_beta(self.x, beta=1)
desired = self.x * torch.sigmoid(self.x)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
del self.x
class TestCompetitions(unittest.TestCase):
def setUp(self):
pass
def test_wtac(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.wtac(d, labels)
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_wtac_one_hot(self):
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
labels = torch.tensor([[0, 1], [1, 0]])
actual = competitions.wtac(d, labels)
desired = torch.tensor([[0, 1], [1, 0]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_knnc_k1(self):
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
labels = torch.tensor([0, 1, 2, 3])
actual = competitions.knnc(d, labels, k=1)
desired = torch.tensor([2, 0])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self):
pass
class TestInitializers(unittest.TestCase): class TestInitializers(unittest.TestCase):
def setUp(self): def setUp(self):
self.flist = [
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
'stratified_random'
]
self.x = torch.tensor( self.x = torch.tensor(
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]], [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
dtype=torch.float32) dtype=torch.float32)
@@ -274,11 +324,7 @@ class TestInitializers(unittest.TestCase):
self.assertIsNotNone(initializers.INITIALIZERS) self.assertIsNotNone(initializers.INITIALIZERS)
def test_funcname_deserialization(self): def test_funcname_deserialization(self):
flist = [ for funcname in self.flist:
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
'stratified_random'
]
for funcname in flist:
f = initializers.get_initializer(funcname) f = initializers.get_initializer(funcname)
iscallable = callable(f) iscallable = callable(f)
self.assertTrue(iscallable) self.assertTrue(iscallable)
@@ -336,7 +382,7 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_equal1(self): def test_stratified_mean_equal1(self):
pdist = torch.tensor([1, 1]) pdist = torch.tensor([1, 1])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist) actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]]) desired = torch.tensor([[5., 5., 5.], [1., 1., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
@@ -345,8 +391,9 @@ class TestInitializers(unittest.TestCase):
def test_stratified_random_equal1(self): def test_stratified_random_equal1(self):
pdist = torch.tensor([1, 1]) pdist = torch.tensor([1, 1])
actual, _ = initializers.stratified_random(self.x, self.y, pdist) actual, _ = initializers.stratified_random(self.x, self.y, pdist,
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.]]) False)
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
@@ -354,7 +401,7 @@ class TestInitializers(unittest.TestCase):
def test_stratified_mean_equal2(self): def test_stratified_mean_equal2(self):
pdist = torch.tensor([2, 2]) pdist = torch.tensor([2, 2])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist) actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.], desired = torch.tensor([[5., 5., 5.], [5., 5., 5.], [1., 1., 1.],
[1., 1., 1.]]) [1., 1., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
@@ -362,9 +409,20 @@ class TestInitializers(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_stratified_random_equal2(self):
pdist = torch.tensor([2, 2])
actual, _ = initializers.stratified_random(self.x, self.y, pdist,
False)
desired = torch.tensor([[0., -1., -2.], [0., -1., -2.], [0., 0., 0.],
[0., 0., 0.]])
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_mean_unequal(self): def test_stratified_mean_unequal(self):
pdist = torch.tensor([1, 3]) pdist = torch.tensor([1, 3])
actual, _ = initializers.stratified_mean(self.x, self.y, pdist) actual, _ = initializers.stratified_mean(self.x, self.y, pdist, False)
desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.], desired = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
[1., 1., 1.]]) [1., 1., 1.]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
@@ -374,14 +432,86 @@ class TestInitializers(unittest.TestCase):
def test_stratified_random_unequal(self): def test_stratified_random_unequal(self):
pdist = torch.tensor([1, 3]) pdist = torch.tensor([1, 3])
actual, _ = initializers.stratified_random(self.x, self.y, pdist) actual, _ = initializers.stratified_random(self.x, self.y, pdist,
desired = torch.tensor([[0., -1., -2.], [2., 2., 2.], [0., 0., 0.], False)
desired = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]]) [0., 0., 0.]])
mismatch = np.testing.assert_array_almost_equal(actual, mismatch = np.testing.assert_array_almost_equal(actual,
desired, desired,
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_stratified_mean_unequal_one_hot(self):
pdist = torch.tensor([1, 3])
y = torch.eye(2)[self.y]
desired1 = torch.tensor([[5., 5., 5.], [1., 1., 1.], [1., 1., 1.],
[1., 1., 1.]])
actual1, actual2 = initializers.stratified_mean(self.x, y, pdist)
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
mismatch = np.testing.assert_array_almost_equal(actual1,
desired1,
decimal=5)
mismatch = np.testing.assert_array_almost_equal(actual2,
desired2,
decimal=5)
self.assertIsNone(mismatch)
def test_stratified_random_unequal_one_hot(self):
pdist = torch.tensor([1, 3])
y = torch.eye(2)[self.y]
actual1, actual2 = initializers.stratified_random(self.x, y, pdist)
desired1 = torch.tensor([[0., -1., -2.], [0., 0., 0.], [0., 0., 0.],
[0., 0., 0.]])
desired2 = torch.tensor([[1, 0], [0, 1], [0, 1], [0, 1]])
mismatch = np.testing.assert_array_almost_equal(actual1,
desired1,
decimal=5)
mismatch = np.testing.assert_array_almost_equal(actual2,
desired2,
decimal=5)
self.assertIsNone(mismatch)
def tearDown(self): def tearDown(self):
del self.x, self.y, self.gen del self.x, self.y, self.gen
_ = torch.seed() _ = torch.seed()
class TestLosses(unittest.TestCase):
def setUp(self):
pass
def test_glvq_loss_int_labels(self):
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([0, 1])
targets = torch.ones(100)
batch_loss = losses.glvq_loss(distances=d,
target_labels=targets,
prototype_labels=labels)
loss_value = torch.sum(batch_loss, dim=0)
self.assertEqual(loss_value, -100)
def test_glvq_loss_one_hot_labels(self):
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([[0, 1], [1, 0]])
wl = torch.tensor([1, 0])
targets = torch.stack([wl for _ in range(100)], dim=0)
batch_loss = losses.glvq_loss(distances=d,
target_labels=targets,
prototype_labels=labels)
loss_value = torch.sum(batch_loss, dim=0)
self.assertEqual(loss_value, -100)
def test_glvq_loss_one_hot_unequal(self):
dlist = [torch.ones(100), torch.zeros(100), torch.zeros(100)]
d = torch.stack(dlist, dim=1)
labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
wl = torch.tensor([1, 0])
targets = torch.stack([wl for _ in range(100)], dim=0)
batch_loss = losses.glvq_loss(distances=d,
target_labels=targets,
prototype_labels=labels)
loss_value = torch.sum(batch_loss, dim=0)
self.assertEqual(loss_value, -100)
def tearDown(self):
pass

View File

@@ -5,7 +5,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from prototorch.modules import prototypes, losses from prototorch.modules import losses, prototypes
class TestPrototypes(unittest.TestCase): class TestPrototypes(unittest.TestCase):
@@ -16,19 +16,23 @@ class TestPrototypes(unittest.TestCase):
self.y = torch.tensor([0, 0, 1, 1]) self.y = torch.tensor([0, 0, 1, 1])
self.gen = torch.manual_seed(42) self.gen = torch.manual_seed(42)
def test_addprototypes1d_init_without_input_dim(self): def test_prototypes1d_init_without_input_dim(self):
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = prototypes.AddPrototypes1D(nclasses=1) _ = prototypes.Prototypes1D(nclasses=2)
def test_addprototypes1d_init_without_nclasses(self): def test_prototypes1d_init_without_nclasses(self):
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = prototypes.AddPrototypes1D(input_dim=1) _ = prototypes.Prototypes1D(input_dim=1)
def test_addprototypes1d_init_without_pdist(self): def test_prototypes1d_init_with_nclasses_1(self):
p1 = prototypes.AddPrototypes1D(input_dim=6, with self.assertWarns(UserWarning):
nclasses=2, _ = prototypes.Prototypes1D(nclasses=1, input_dim=1)
prototypes_per_class=4,
prototype_initializer='ones') def test_prototypes1d_init_without_pdist(self):
p1 = prototypes.Prototypes1D(input_dim=6,
nclasses=2,
prototypes_per_class=4,
prototype_initializer='ones')
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.ones(8, 6) desired = torch.ones(8, 6)
@@ -37,11 +41,11 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_init_without_data(self): def test_prototypes1d_init_without_data(self):
pdist = [2, 2] pdist = [2, 2]
p1 = prototypes.AddPrototypes1D(input_dim=3, p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist, prototype_distribution=pdist,
prototype_initializer='zeros') prototype_initializer='zeros')
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.zeros(4, 3) desired = torch.zeros(4, 3)
@@ -50,23 +54,20 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
# def test_addprototypes1d_init_torch_pdist(self): def test_prototypes1d_proto_init_without_data(self):
# pdist = torch.tensor([2, 2]) with self.assertWarns(UserWarning):
# p1 = prototypes.AddPrototypes1D(input_dim=3, _ = prototypes.Prototypes1D(
# prototype_distribution=pdist, input_dim=3,
# prototype_initializer='zeros') nclasses=2,
# protos = p1.prototypes prototypes_per_class=1,
# actual = protos.detach().numpy() prototype_initializer='stratified_mean',
# desired = torch.zeros(4, 3) data=None)
# mismatch = np.testing.assert_array_almost_equal(actual,
# desired,
# decimal=5)
# self.assertIsNone(mismatch)
def test_addprototypes1d_init_with_ppc(self): def test_prototypes1d_init_torch_pdist(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y], pdist = torch.tensor([2, 2])
prototypes_per_class=2, p1 = prototypes.Prototypes1D(input_dim=3,
prototype_initializer='zeros') prototype_distribution=pdist,
prototype_initializer='zeros')
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.zeros(4, 3) desired = torch.zeros(4, 3)
@@ -75,10 +76,119 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_init_with_pdist(self): def test_prototypes1d_init_without_inputdim_with_data(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y], _ = prototypes.Prototypes1D(nclasses=2,
prototype_distribution=[6, 9], prototypes_per_class=1,
prototype_initializer='zeros') prototype_initializer='stratified_mean',
data=[[[1.], [0.]], [1, 0]])
def test_prototypes1d_init_with_int_data(self):
_ = prototypes.Prototypes1D(nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1], [0]], [1, 0]])
def test_prototypes1d_init_one_hot_without_data(self):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=None,
one_hot_labels=True)
def test_prototypes1d_init_one_hot_labels_false(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
but the provided `data` has one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=([[0.], [1.]], [[0, 1], [1, 0]]),
one_hot_labels=False)
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
but the provided `data` does not contain one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=([[0.], [1.]], [0, 1]),
one_hot_labels=True)
def test_prototypes1d_init_one_hot_labels_true(self):
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
but the provided `data` contains 2D targets but
does not contain one-hot encoded labels.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=([[0.], [1.]], [[0], [1]]),
one_hot_labels=True)
def test_prototypes1d_init_with_int_dtype(self):
with self.assertRaises(RuntimeError):
_ = prototypes.Prototypes1D(
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1], [0]], [1, 0]],
dtype=torch.int32)
def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=1,
prototypes_per_class=1,
data=[[1.], [1]])
def test_prototypes1d_inputdim_with_data(self):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=2,
nclasses=2,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.], [0.]], [1, 0]])
def test_prototypes1d_nclasses_with_data(self):
"""Test ValueError raise if provided `nclasses` is not the same
as the one computed from the provided `data`.
"""
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=1,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.], [2.]], [1, 2]])
def test_prototypes1d_init_with_ppc(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototypes_per_class=2,
prototype_initializer='zeros')
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_init_with_pdist(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototype_distribution=[6, 9],
prototype_initializer='zeros')
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.zeros(15, 3) desired = torch.zeros(15, 3)
@@ -87,14 +197,14 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_func_initializer(self): def test_prototypes1d_func_initializer(self):
def my_initializer(*args, **kwargs): def my_initializer(*args, **kwargs):
return torch.full((2, 99), 99), torch.tensor([0, 1]) return torch.full((2, 99), 99), torch.tensor([0, 1])
p1 = prototypes.AddPrototypes1D(input_dim=99, p1 = prototypes.Prototypes1D(input_dim=99,
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer=my_initializer) prototype_initializer=my_initializer)
protos = p1.prototypes protos = p1.prototypes
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = 99 * torch.ones(2, 99) desired = 99 * torch.ones(2, 99)
@@ -103,8 +213,8 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_forward(self): def test_prototypes1d_forward(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y]) p1 = prototypes.Prototypes1D(data=[self.x, self.y])
protos, _ = p1() protos, _ = p1()
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.ones(2, 3) desired = torch.ones(2, 3)
@@ -113,6 +223,16 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_prototypes1d_dist_validate(self):
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
with self.assertWarns(UserWarning):
_ = p1._validate_prototype_distribution()
def test_prototypes1d_validate_extra_repr_not_empty(self):
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
rep = p1.extra_repr()
self.assertNotEqual(rep, '')
def tearDown(self): def tearDown(self):
del self.x, self.y, self.gen del self.x, self.y, self.gen
_ = torch.seed() _ = torch.seed()
@@ -123,7 +243,19 @@ class TestLosses(unittest.TestCase):
pass pass
def test_glvqloss_init(self): def test_glvqloss_init(self):
_ = losses.GLVQLoss() _ = losses.GLVQLoss(0, 'swish_beta', beta=20)
def test_glvqloss_forward(self):
criterion = losses.GLVQLoss(margin=0,
squashing='sigmoid_beta',
beta=100)
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
labels = torch.tensor([0, 1])
targets = torch.ones(100)
outputs = [d, labels]
loss = criterion(outputs, targets)
loss_value = loss.item()
self.assertAlmostEqual(loss_value, 0.0)
def tearDown(self): def tearDown(self):
pass pass