Compare commits
18 Commits
v0.1.0-dev
...
v0.1.1-dev
Author | SHA1 | Date | |
---|---|---|---|
|
438a5b9360 | ||
|
f98f3d095e | ||
|
21b0279839 | ||
|
b19cbcb76a | ||
|
7d5ab81dbf | ||
|
bde408a80e | ||
|
900955d67a | ||
|
3757c937b3 | ||
|
38f637aaeb | ||
|
6ddfe48a95 | ||
|
bf0e694321 | ||
|
e2c9848120 | ||
|
dc60b7e5b5 | ||
|
c21913fdd4 | ||
|
59e31f94ab | ||
|
cddefa9b0d | ||
|
26d71fdd60 | ||
|
ced8f532dd |
@@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 0.1.0-dev0
|
current_version = 0.1.1-dev0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
|
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\-(?P<release>[a-z]+)(?P<build>\d+))?
|
||||||
|
2
.github/workflows/pythonapp.yml
vendored
2
.github/workflows/pythonapp.yml
vendored
@@ -1,7 +1,7 @@
|
|||||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||||
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
|
||||||
|
|
||||||
name: Tests
|
name: tests
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
19
.travis.yml
Normal file
19
.travis.yml
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
dist: bionic
|
||||||
|
sudo: false
|
||||||
|
language: python
|
||||||
|
python: 3.8
|
||||||
|
cache:
|
||||||
|
directories:
|
||||||
|
- ./tests/artifacts
|
||||||
|
|
||||||
|
install:
|
||||||
|
- pip install . --progress-bar off
|
||||||
|
- pip install codecov
|
||||||
|
- pip install pytest
|
||||||
|
|
||||||
|
script:
|
||||||
|
- coverage run -m pytest
|
||||||
|
|
||||||
|
# Push the results to codecov
|
||||||
|
after_success:
|
||||||
|
- codecov
|
@@ -1,9 +1,11 @@
|
|||||||
include .bumpversion.cfg
|
include .bumpversion.cfg
|
||||||
include LICENSE
|
include LICENSE
|
||||||
include tox.ini
|
include tox.ini
|
||||||
|
include *.yml
|
||||||
recursive-include docs *.bat
|
recursive-include docs *.bat
|
||||||
recursive-include docs *.png
|
recursive-include docs *.png
|
||||||
recursive-include docs *.py
|
recursive-include docs *.py
|
||||||
recursive-include docs *.rst
|
recursive-include docs *.rst
|
||||||
recursive-include docs Makefile
|
recursive-include docs Makefile
|
||||||
recursive-include examples *.py
|
recursive-include examples *.py
|
||||||
|
recursive-include tests *.py
|
||||||
|
20
README.md
20
README.md
@@ -3,8 +3,13 @@
|
|||||||
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
|
ProtoTorch is a PyTorch-based Python toolbox for bleeding-edge research in
|
||||||
prototype-based machine learning algorithms.
|
prototype-based machine learning algorithms.
|
||||||
|
|
||||||

|
[](https://travis-ci.org/si-cim/prototorch)
|
||||||
|
[](https://badge.fury.io/gh/si-cim%2Fprototorch)
|
||||||
|
[](https://badge.fury.io/py/prototorch)
|
||||||
|

|
||||||
[](https://codecov.io/gh/si-cim/prototorch)
|
[](https://codecov.io/gh/si-cim/prototorch)
|
||||||
|
[](https://pepy.tech/project/prototorch)
|
||||||
|
[](https://github.com/si-cim/prototorch/blob/master/LICENSE)
|
||||||
|
|
||||||
## Description
|
## Description
|
||||||
|
|
||||||
@@ -47,3 +52,16 @@ API, with more algorithms and techniques coming soon. If you would simply like
|
|||||||
to be able to use those algorithms to train large ML models on a GPU, ProtoTorch
|
to be able to use those algorithms to train large ML models on a GPU, ProtoTorch
|
||||||
lets you do this without requiring a black-belt in high-performance Tensor
|
lets you do this without requiring a black-belt in high-performance Tensor
|
||||||
computation.
|
computation.
|
||||||
|
|
||||||
|
## Bibtex
|
||||||
|
|
||||||
|
If you would like to cite the package, please use this:
|
||||||
|
```bibtex
|
||||||
|
@misc{Ravichandran2020,
|
||||||
|
author = {Ravichandran, J},
|
||||||
|
title = {ProtoTorch},
|
||||||
|
year = {2020},
|
||||||
|
publisher = {GitHub},
|
||||||
|
journal = {GitHub repository},
|
||||||
|
howpublished = {\url{https://github.com/si-cim/prototorch}}
|
||||||
|
}
|
||||||
|
3
RELEASE.md
Normal file
3
RELEASE.md
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# Release 0.1.0-dev0
|
||||||
|
|
||||||
|
Initial public release of ProtoTorch.
|
@@ -1 +1 @@
|
|||||||
__version__ = '0.1.0-dev0'
|
__version__ = '0.1.1-dev0'
|
||||||
|
@@ -5,30 +5,36 @@ import torch
|
|||||||
ACTIVATIONS = dict()
|
ACTIVATIONS = dict()
|
||||||
|
|
||||||
|
|
||||||
def register_activation(func):
|
# def register_activation(scriptf):
|
||||||
ACTIVATIONS[func.__name__] = func
|
# ACTIVATIONS[scriptf.name] = scriptf
|
||||||
return func
|
# return scriptf
|
||||||
|
def register_activation(f):
|
||||||
|
ACTIVATIONS[f.__name__] = f
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
def identity(input, **kwargs):
|
# @torch.jit.script
|
||||||
|
def identity(input, beta=torch.tensor([0])):
|
||||||
""":math:`f(x) = x`"""
|
""":math:`f(x) = x`"""
|
||||||
return input
|
return input
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
def sigmoid_beta(input, beta=10):
|
# @torch.jit.script
|
||||||
|
def sigmoid_beta(input, beta=torch.tensor([10])):
|
||||||
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
|
""":math:`f(x) = \\frac{1}{1 + e^{-\\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
beta (float): Parameter :math:`\\beta`
|
beta (float): Parameter :math:`\\beta`
|
||||||
"""
|
"""
|
||||||
out = torch.reciprocal(1.0 + torch.exp(-beta * input))
|
out = torch.reciprocal(1.0 + torch.exp(-int(beta.item()) * input))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_activation
|
@register_activation
|
||||||
def swish_beta(input, beta=10):
|
# @torch.jit.script
|
||||||
|
def swish_beta(input, beta=torch.tensor([10])):
|
||||||
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
|
""":math:`f(x) = \\frac{x}{1 + e^{-\\beta x}}`
|
||||||
|
|
||||||
Keyword Arguments:
|
Keyword Arguments:
|
||||||
|
@@ -3,13 +3,15 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
def wtac(distances, labels):
|
def wtac(distances, labels):
|
||||||
winning_indices = torch.min(distances, dim=1).indices
|
winning_indices = torch.min(distances, dim=1).indices
|
||||||
winning_labels = labels[winning_indices].squeeze()
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
return winning_labels
|
return winning_labels
|
||||||
|
|
||||||
|
|
||||||
|
# @torch.jit.script
|
||||||
def knnc(distances, labels, k):
|
def knnc(distances, labels, k):
|
||||||
winning_indices = torch.topk(-distances, k=k, dim=1).indices
|
winning_indices = torch.topk(-distances, k=k.item(), dim=1).indices
|
||||||
winning_labels = labels[winning_indices].squeeze()
|
winning_labels = labels[winning_indices].squeeze()
|
||||||
return winning_labels
|
return winning_labels
|
||||||
|
@@ -33,13 +33,6 @@ def lpnorm_distance(x, y, p):
|
|||||||
Expected dimension of x is 2.
|
Expected dimension of x is 2.
|
||||||
Expected dimension of y is 2.
|
Expected dimension of y is 2.
|
||||||
"""
|
"""
|
||||||
# # DEPRECATED in favor of torch.cdist
|
|
||||||
# expanded_x = x.unsqueeze(dim=1)
|
|
||||||
# batchwise_difference = y - expanded_x
|
|
||||||
# differences_raised = torch.pow(batchwise_difference, p)
|
|
||||||
# distances_raised = torch.sum(differences_raised, axis=2)
|
|
||||||
# distances = torch.pow(distances_raised, 1.0 / p)
|
|
||||||
# return distances
|
|
||||||
distances = torch.cdist(x, y, p=p)
|
distances = torch.cdist(x, y, p=p)
|
||||||
return distances
|
return distances
|
||||||
|
|
||||||
|
@@ -12,12 +12,9 @@ def glvq_loss(distances, target_labels, prototype_labels):
|
|||||||
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses)
|
||||||
not_matcher = torch.bitwise_not(matcher)
|
not_matcher = torch.bitwise_not(matcher)
|
||||||
|
|
||||||
dplus_criterion = distances * matcher > 0.0
|
|
||||||
dminus_criterion = distances * not_matcher > 0.0
|
|
||||||
|
|
||||||
inf = torch.full_like(distances, fill_value=float('inf'))
|
inf = torch.full_like(distances, fill_value=float('inf'))
|
||||||
distances_to_wpluses = torch.where(dplus_criterion, distances, inf)
|
distances_to_wpluses = torch.where(matcher, distances, inf)
|
||||||
distances_to_wminuses = torch.where(dminus_criterion, distances, inf)
|
distances_to_wminuses = torch.where(not_matcher, distances, inf)
|
||||||
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
|
dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values
|
||||||
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
|
dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values
|
||||||
|
|
||||||
|
@@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.squashing = get_activation(squashing)
|
self.squashing = get_activation(squashing)
|
||||||
self.beta = beta
|
self.beta = torch.tensor(beta)
|
||||||
|
|
||||||
def forward(self, outputs, targets):
|
def forward(self, outputs, targets):
|
||||||
distances, plabels = outputs
|
distances, plabels = outputs
|
||||||
|
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ with open('README.md', 'r') as fh:
|
|||||||
long_description = fh.read()
|
long_description = fh.read()
|
||||||
|
|
||||||
setup(name='prototorch',
|
setup(name='prototorch',
|
||||||
version='0.1.0-dev0',
|
version='0.1.1-dev0',
|
||||||
description='Highly extensible, GPU-supported '
|
description='Highly extensible, GPU-supported '
|
||||||
'Learning Vector Quantization (LVQ) toolbox '
|
'Learning Vector Quantization (LVQ) toolbox '
|
||||||
'built using PyTorch and its nn API.',
|
'built using PyTorch and its nn API.',
|
||||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@@ -6,7 +6,107 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions import (activations, competitions, distances,
|
from prototorch.functions import (activations, competitions, distances,
|
||||||
initializers)
|
initializers, losses)
|
||||||
|
|
||||||
|
|
||||||
|
class TestActivations(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
||||||
|
self.x = torch.randn(1024, 1)
|
||||||
|
|
||||||
|
def test_registry(self):
|
||||||
|
self.assertIsNotNone(activations.ACTIVATIONS)
|
||||||
|
|
||||||
|
def test_funcname_deserialization(self):
|
||||||
|
for funcname in self.flist:
|
||||||
|
f = activations.get_activation(funcname)
|
||||||
|
iscallable = callable(f)
|
||||||
|
self.assertTrue(iscallable)
|
||||||
|
|
||||||
|
# def test_torch_script(self):
|
||||||
|
# for funcname in self.flist:
|
||||||
|
# f = activations.get_activation(funcname)
|
||||||
|
# self.assertIsInstance(f, torch.jit.ScriptFunction)
|
||||||
|
|
||||||
|
def test_callable_deserialization(self):
|
||||||
|
def dummy(x, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
for f in [dummy, lambda x: x]:
|
||||||
|
f = activations.get_activation(f)
|
||||||
|
iscallable = callable(f)
|
||||||
|
self.assertTrue(iscallable)
|
||||||
|
self.assertEqual(1, f(1))
|
||||||
|
|
||||||
|
def test_unknown_deserialization(self):
|
||||||
|
for funcname in ['blubb', 'foobar']:
|
||||||
|
with self.assertRaises(NameError):
|
||||||
|
_ = activations.get_activation(funcname)
|
||||||
|
|
||||||
|
def test_identity(self):
|
||||||
|
actual = activations.identity(self.x)
|
||||||
|
desired = self.x
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_sigmoid_beta1(self):
|
||||||
|
actual = activations.sigmoid_beta(self.x, beta=torch.tensor(1))
|
||||||
|
desired = torch.sigmoid(self.x)
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_swish_beta1(self):
|
||||||
|
actual = activations.swish_beta(self.x, beta=torch.tensor(1))
|
||||||
|
desired = self.x * torch.sigmoid(self.x)
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
del self.x
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompetitions(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_wtac(self):
|
||||||
|
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_wtac_one_hot(self):
|
||||||
|
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
||||||
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
actual = competitions.wtac(d, labels)
|
||||||
|
desired = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_knnc_k1(self):
|
||||||
|
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
||||||
|
labels = torch.tensor([0, 1, 2, 3])
|
||||||
|
actual = competitions.knnc(d, labels, k=torch.tensor([1]))
|
||||||
|
desired = torch.tensor([2, 0])
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TestDistances(unittest.TestCase):
|
class TestDistances(unittest.TestCase):
|
||||||
@@ -167,103 +267,12 @@ class TestDistances(unittest.TestCase):
|
|||||||
del self.x, self.y
|
del self.x, self.y
|
||||||
|
|
||||||
|
|
||||||
class TestActivations(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.x = torch.randn(1024, 1)
|
|
||||||
|
|
||||||
def test_registry(self):
|
|
||||||
self.assertIsNotNone(activations.ACTIVATIONS)
|
|
||||||
|
|
||||||
def test_funcname_deserialization(self):
|
|
||||||
flist = ['identity', 'sigmoid_beta', 'swish_beta']
|
|
||||||
for funcname in flist:
|
|
||||||
f = activations.get_activation(funcname)
|
|
||||||
iscallable = callable(f)
|
|
||||||
self.assertTrue(iscallable)
|
|
||||||
|
|
||||||
def test_callable_deserialization(self):
|
|
||||||
def dummy(x, **kwargs):
|
|
||||||
return x
|
|
||||||
|
|
||||||
for f in [dummy, lambda x: x]:
|
|
||||||
f = activations.get_activation(f)
|
|
||||||
iscallable = callable(f)
|
|
||||||
self.assertTrue(iscallable)
|
|
||||||
self.assertEqual(1, f(1))
|
|
||||||
|
|
||||||
def test_unknown_deserialization(self):
|
|
||||||
for funcname in ['blubb', 'foobar']:
|
|
||||||
with self.assertRaises(NameError):
|
|
||||||
_ = activations.get_activation(funcname)
|
|
||||||
|
|
||||||
def test_identity(self):
|
|
||||||
actual = activations.identity(self.x)
|
|
||||||
desired = self.x
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_sigmoid_beta1(self):
|
|
||||||
actual = activations.sigmoid_beta(self.x, beta=1)
|
|
||||||
desired = torch.sigmoid(self.x)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_swish_beta1(self):
|
|
||||||
actual = activations.swish_beta(self.x, beta=1)
|
|
||||||
desired = self.x * torch.sigmoid(self.x)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
del self.x
|
|
||||||
|
|
||||||
|
|
||||||
class TestCompetitions(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_wtac(self):
|
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
|
||||||
actual = competitions.wtac(d, labels)
|
|
||||||
desired = torch.tensor([2, 0])
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_wtac_one_hot(self):
|
|
||||||
d = torch.tensor([[1.99, 3.01], [3., 2.01]])
|
|
||||||
labels = torch.tensor([[0, 1], [1, 0]])
|
|
||||||
actual = competitions.wtac(d, labels)
|
|
||||||
desired = torch.tensor([[0, 1], [1, 0]])
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_knnc_k1(self):
|
|
||||||
d = torch.tensor([[2., 3., 1.99, 3.01], [2., 3., 2.01, 3.]])
|
|
||||||
labels = torch.tensor([0, 1, 2, 3])
|
|
||||||
actual = competitions.knnc(d, labels, k=1)
|
|
||||||
desired = torch.tensor([2, 0])
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TestInitializers(unittest.TestCase):
|
class TestInitializers(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
self.flist = [
|
||||||
|
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
||||||
|
'stratified_random'
|
||||||
|
]
|
||||||
self.x = torch.tensor(
|
self.x = torch.tensor(
|
||||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
||||||
dtype=torch.float32)
|
dtype=torch.float32)
|
||||||
@@ -274,11 +283,7 @@ class TestInitializers(unittest.TestCase):
|
|||||||
self.assertIsNotNone(initializers.INITIALIZERS)
|
self.assertIsNotNone(initializers.INITIALIZERS)
|
||||||
|
|
||||||
def test_funcname_deserialization(self):
|
def test_funcname_deserialization(self):
|
||||||
flist = [
|
for funcname in self.flist:
|
||||||
'zeros', 'ones', 'rand', 'randn', 'stratified_mean',
|
|
||||||
'stratified_random'
|
|
||||||
]
|
|
||||||
for funcname in flist:
|
|
||||||
f = initializers.get_initializer(funcname)
|
f = initializers.get_initializer(funcname)
|
||||||
iscallable = callable(f)
|
iscallable = callable(f)
|
||||||
self.assertTrue(iscallable)
|
self.assertTrue(iscallable)
|
||||||
@@ -385,3 +390,32 @@ class TestInitializers(unittest.TestCase):
|
|||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
del self.x, self.y, self.gen
|
del self.x, self.y, self.gen
|
||||||
_ = torch.seed()
|
_ = torch.seed()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLosses(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_glvq_loss_int_labels(self):
|
||||||
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
|
labels = torch.tensor([0, 1])
|
||||||
|
targets = torch.ones(100)
|
||||||
|
batch_loss = losses.glvq_loss(distances=d,
|
||||||
|
target_labels=targets,
|
||||||
|
prototype_labels=labels)
|
||||||
|
loss_value = torch.sum(batch_loss, dim=0)
|
||||||
|
self.assertEqual(loss_value, -100)
|
||||||
|
|
||||||
|
def test_glvq_loss_one_hot_labels(self):
|
||||||
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
|
labels = torch.tensor([[0, 1], [1, 0]])
|
||||||
|
wl = torch.tensor([1, 0])
|
||||||
|
targets = torch.stack([wl for _ in range(100)], dim=0)
|
||||||
|
batch_loss = losses.glvq_loss(distances=d,
|
||||||
|
target_labels=targets,
|
||||||
|
prototype_labels=labels)
|
||||||
|
loss_value = torch.sum(batch_loss, dim=0)
|
||||||
|
self.assertEqual(loss_value, -100)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
pass
|
||||||
|
@@ -123,7 +123,19 @@ class TestLosses(unittest.TestCase):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def test_glvqloss_init(self):
|
def test_glvqloss_init(self):
|
||||||
_ = losses.GLVQLoss()
|
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
||||||
|
|
||||||
|
def test_glvqloss_forward(self):
|
||||||
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
|
squashing='sigmoid_beta',
|
||||||
|
beta=100)
|
||||||
|
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
||||||
|
labels = torch.tensor([0, 1])
|
||||||
|
targets = torch.ones(100)
|
||||||
|
outputs = [d, labels]
|
||||||
|
loss = criterion(outputs, targets)
|
||||||
|
loss_value = loss.item()
|
||||||
|
self.assertAlmostEqual(loss_value, 0.0)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
pass
|
||||||
|
10
tox.ini
10
tox.ini
@@ -4,12 +4,12 @@
|
|||||||
# and then run "tox" from this directory.
|
# and then run "tox" from this directory.
|
||||||
|
|
||||||
[tox]
|
[tox]
|
||||||
envlist = py36
|
envlist = py36,py37,py38
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
deps =
|
deps =
|
||||||
numpy
|
pytest
|
||||||
unittest-xml-reporting
|
coverage
|
||||||
commands =
|
commands =
|
||||||
python -m xmlrunner -o reports
|
pip install -e .
|
||||||
|
coverage run -m pytest
|
Reference in New Issue
Block a user