778 lines
28 KiB
Python
778 lines
28 KiB
Python
"""ProtoTorch core test suite"""
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
import prototorch as pt
|
|
from prototorch.utils import parse_distribution
|
|
|
|
|
|
# Utils
|
|
def test_parse_distribution_dict_0():
    """A single class with zero samples per class maps to ``{0: 0}``."""
    parsed = parse_distribution({"num_classes": 1, "per_class": 0})
    assert parsed == {0: 0}
|
|
|
|
|
|
def test_parse_distribution_dict_1():
    """``num_classes``/``per_class`` keys expand to one entry per class."""
    spec = {"num_classes": 3, "per_class": 2}
    parsed = parse_distribution(spec)
    assert parsed == {0: 2, 1: 2, 2: 2}
|
|
|
|
|
|
def test_parse_distribution_dict_2():
    """An explicit ``label -> count`` mapping comes back unchanged."""
    spec = {0: 1, 2: 2, -1: 3}
    parsed = parse_distribution(spec)
    assert parsed == {0: 1, 2: 2, -1: 3}
|
|
|
|
|
|
def test_parse_distribution_tuple():
    """A ``(num_classes, per_class)`` tuple expands to a uniform mapping."""
    num_classes, per_class = 2, 3
    parsed = parse_distribution((num_classes, per_class))
    assert parsed == {label: per_class for label in range(num_classes)}
|
|
|
|
|
|
def test_parse_distribution_list():
    """A list of counts is keyed by position (0, 1, 2, ...)."""
    counts = [1, 1, 0, 2]
    assert parse_distribution(counts) == dict(enumerate(counts))
|
|
|
|
|
|
def test_parse_distribution_custom_labels():
    """Custom class labels replace positional keys, pairing up in order."""
    counts = [1, 1, 0, 2]
    custom_labels = [1, 2, 5, 3]
    parsed = parse_distribution(counts, custom_labels)
    assert parsed == dict(zip(custom_labels, counts))
|
|
|
|
|
|
# Components initializers
|
|
def test_literal_comp_generate():
    """A literal tensor comes back unchanged from ``generate``."""
    literal = torch.rand(4, 3, 5, 5)
    initializer = pt.initializers.LiteralCompInitializer(literal)
    assert torch.allclose(initializer.generate([]), literal)
|
|
|
|
|
|
def test_literal_comp_generate_from_list():
    """A nested Python list is accepted, emits a UserWarning, and matches
    the equivalent ``torch.Tensor``."""
    literal = [[0, 1], [2, 3], [4, 5]]
    initializer = pt.initializers.LiteralCompInitializer(literal)
    with pytest.warns(UserWarning):
        generated = initializer.generate([])
    assert torch.allclose(generated, torch.Tensor(literal))
|
|
|
|
|
|
def test_shape_aware_raises_error():
    """Instantiating ``ShapeAwareCompInitializer`` directly raises TypeError
    (presumably because it is an abstract base class)."""
    with pytest.raises(TypeError):
        pt.initializers.ShapeAwareCompInitializer(shape=(2, ))
|
|
|
|
|
|
def test_data_aware_comp_generate():
    """The data-aware initializer returns its data; the (bogus)
    ``num_components`` argument is ignored."""
    data = torch.rand(4, 3, 5, 5)
    initializer = pt.initializers.DataAwareCompInitializer(data)
    generated = initializer.generate(num_components="IgnoreMe!")
    assert torch.allclose(generated, data)
|
|
|
|
|
|
def test_class_aware_comp_generate():
    """The class-aware initializer returns the samples of its ``[x, y]``
    pair when given an empty distribution."""
    samples = torch.rand(4, 2, 3, 5, 5)
    targets = torch.tensor([0, 0, 1, 1]).long()
    initializer = pt.initializers.ClassAwareCompInitializer(
        [samples, targets])
    generated = initializer.generate(distribution=[])
    assert torch.allclose(generated, samples)
|
|
|
|
|
|
def test_zeros_comp_generate():
    """The zeros initializer honours a multi-dimensional component shape."""
    component_shape = (3, 5, 5)
    initializer = pt.initializers.ZerosCompInitializer(component_shape)
    generated = initializer.generate(num_components=4)
    assert torch.allclose(generated, torch.zeros(4, *component_shape))
|
|
|
|
|
|
def test_ones_comp_generate():
    """An integer shape gives 1-D components filled with ones."""
    initializer = pt.initializers.OnesCompInitializer(2)
    generated = initializer.generate(num_components=3)
    expected = torch.ones(3, 2)
    assert torch.allclose(generated, expected)
|
|
|
|
|
|
def test_fill_value_comp_generate():
    """A fill value of 0.0 yields all-zero components."""
    fill = 0.0
    initializer = pt.initializers.FillValueCompInitializer(2, fill)
    generated = initializer.generate(num_components=3)
    assert torch.allclose(generated, torch.zeros(3, 2))
|
|
|
|
|
|
def test_uniform_comp_generate_min_max_bound():
    """Uniformly sampled components stay within the requested bounds."""
    low, high = -1.0, 1.0
    initializer = pt.initializers.UniformCompInitializer(2, low, high)
    generated = initializer.generate(num_components=1024)
    assert generated.min() >= low
    assert generated.max() <= high
|
|
|
|
|
|
def test_random_comp_generate_mean():
    """Normally sampled components average close to the second constructor
    argument (-1.0 here), within an absolute tolerance of 0.1."""
    initializer = pt.initializers.RandomNormalCompInitializer(2, -1.0)
    generated = initializer.generate(num_components=1024)
    target_mean = torch.tensor(-1.0)
    assert torch.allclose(
        generated.mean(),
        target_mean,
        rtol=1e-05,
        atol=1e-01,
    )
|
|
|
|
|
|
def test_comp_generate_0_components():
    """Requesting zero components yields an empty ``(0, 2)`` tensor.

    Asserting on the shape, rather than only checking that ``generate``
    does not raise, catches a wrong-but-non-raising result as well.
    """
    initializer = pt.initializers.ZerosCompInitializer(2)
    components = initializer.generate(num_components=0)
    assert components.numel() == 0
    assert tuple(components.shape) == (0, 2)
|
|
|
|
|
|
def test_stratified_mean_comp_generate():
    """With one component per class, each component is its class mean."""
    # yapf: disable
    samples = torch.Tensor([
        [0, -1, -2],
        [10, 11, 12],
        [0, 0, 0],
        [2, 2, 2],
    ])
    targets = torch.LongTensor([0, 0, 1, 1])
    expected = torch.Tensor([
        [5.0, 5.0, 5.0],  # mean of the two class-0 rows
        [1.0, 1.0, 1.0],  # mean of the two class-1 rows
    ])
    # yapf: enable
    initializer = pt.initializers.StratifiedMeanCompInitializer(
        data=[samples, targets])
    assert torch.allclose(initializer.generate([1, 1]), expected)
|
|
|
|
|
|
def test_stratified_selection_comp_generate():
    """With one component per class, each component is a sample of that class.

    All samples within a class are identical here, so the expected result
    does not depend on which sample gets picked.
    """
    # yapf: disable
    samples = torch.Tensor([
        [0, 0, 0],
        [1, 1, 1],
        [0, 0, 0],
        [1, 1, 1],
    ])
    targets = torch.LongTensor([0, 1, 0, 1])
    expected = torch.Tensor([
        [0, 0, 0],
        [1, 1, 1],
    ])
    # yapf: enable
    initializer = pt.initializers.StratifiedSelectionCompInitializer(
        data=[samples, targets])
    assert torch.allclose(initializer.generate([1, 1]), expected)
|
|
|
|
|
|
# Labels initializers
|
|
def test_literal_labels_init():
    """Literal labels given as a list emit a UserWarning and match the
    equivalent LongTensor."""
    literal = [0, 0, 1, 2]
    initializer = pt.initializers.LiteralLabelsInitializer(literal)
    with pytest.warns(UserWarning):
        generated = initializer.generate([])
    assert torch.allclose(generated, torch.LongTensor([0, 0, 1, 2]))
|
|
|
|
|
|
def test_labels_init_from_list():
    """One component for each of three classes gives labels ``[0, 1, 2]``."""
    initializer = pt.initializers.LabelsInitializer()
    generated = initializer.generate(distribution=[1, 1, 1])
    assert torch.allclose(generated, torch.LongTensor([0, 1, 2]))
|
|
|
|
|
|
def test_labels_init_from_tuple_legal():
    """A ``(3, 1)`` tuple (3 classes, 1 per class) gives ``[0, 1, 2]``."""
    initializer = pt.initializers.LabelsInitializer()
    generated = initializer.generate(distribution=(3, 1))
    assert torch.allclose(generated, torch.LongTensor([0, 1, 2]))
|
|
|
|
|
|
def test_labels_init_from_tuple_illegal():
    """A tuple with more than two entries is rejected with AssertionError."""
    initializer = pt.initializers.LabelsInitializer()
    with pytest.raises(AssertionError):
        initializer.generate(distribution=(1, 1, 1))
|
|
|
|
|
|
def test_data_aware_labels_init():
    """Data-aware labels are the dataset's targets."""
    samples = [0, 1, 2, 3]
    targets = [0, 0, 1, 1]
    dataset = pt.datasets.NumpyDataset(samples, targets)
    initializer = pt.initializers.DataAwareLabelsInitializer(dataset)
    generated = initializer.generate([])
    assert torch.allclose(generated, torch.LongTensor(targets))
|
|
|
|
|
|
# Reasonings initializers
|
|
def test_literal_reasonings_init():
    """Literal reasonings given as a list emit a UserWarning and match the
    equivalent float tensor."""
    literal = [0, 0, 1, 2]
    initializer = pt.initializers.LiteralReasoningsInitializer(literal)
    with pytest.warns(UserWarning):
        generated = initializer.generate([])
    assert torch.allclose(generated, torch.Tensor([0, 0, 1, 2]))
|
|
|
|
|
|
def test_random_reasonings_init():
    """Random reasonings have the expected size and stay in ``[0.2, 0.8]``."""
    low, high = 0.2, 0.8
    initializer = pt.initializers.RandomReasoningsInitializer(low, high)
    generated = initializer.generate(distribution=[0, 1])
    # 1 component, 2 classes, and a trailing axis of size 2.
    assert torch.numel(generated) == 1 * 2 * 2
    assert generated.min() >= low
    assert generated.max() <= high
|
|
|
|
|
|
def test_zeros_reasonings_init():
    """A single component over two classes yields a zero ``(1, 2, 2)`` tensor."""
    initializer = pt.initializers.ZerosReasoningsInitializer()
    generated = initializer.generate(distribution=[0, 1])
    assert torch.allclose(generated, torch.zeros(1, 2, 2))
|
|
|
|
|
|
def test_ones_reasonings_init():
    """Zero reasonings over a multi-class distribution ``[1, 2, 3]``.

    NOTE(review): despite its name, this test exercises
    ``ZerosReasoningsInitializer`` and compares against zeros. 1+2+3 = 6
    components over 3 classes gives a ``(6, 3, 2)`` tensor. Rename or
    retarget once it is confirmed whether a ones-reasonings initializer
    exists.
    """
    initializer = pt.initializers.ZerosReasoningsInitializer()
    generated = initializer.generate(distribution=[1, 2, 3])
    assert torch.allclose(generated, torch.zeros(6, 3, 2))
|
|
|
|
|
|
def test_pure_positive_reasonings_init_one_per_class():
    """With one component per class and channels moved first, the first
    channel is the identity matrix."""
    initializer = pt.initializers.PurePositiveReasoningsInitializer(
        components_first=False)
    generated = initializer.generate(distribution=(4, 1))
    first_channel = generated[0]
    assert torch.allclose(first_channel, torch.eye(4))
|
|
|
|
|
|
def test_pure_positive_reasonings_init_unrepresented_classes():
    """Classes with zero components still get a slot along the class axis."""
    initializer = pt.initializers.PurePositiveReasoningsInitializer()
    generated = initializer.generate(distribution=[9, 0, 0, 0])
    # (components, classes, 2)
    assert tuple(generated.shape[:3]) == (9, 4, 2)
|
|
|
|
|
|
def test_random_reasonings_init_channels_not_first():
    """With ``components_first=False`` the axis order is reversed."""
    initializer = pt.initializers.RandomReasoningsInitializer(
        components_first=False)
    generated = initializer.generate(distribution=[0, 0, 0, 1])
    # (2, classes, components) instead of (components, classes, 2)
    assert tuple(generated.shape[:3]) == (2, 4, 1)
|
|
|
|
|
|
# Transform initializers
|
|
def test_eye_transform_init_square():
    """A square eye transform is the identity matrix."""
    initializer = pt.initializers.EyeLinearTransformInitializer()
    identity = initializer.generate(3, 3)
    assert torch.allclose(identity, torch.eye(3))
|
|
|
|
|
|
def test_eye_transform_init_narrow():
    """A 3x2 eye transform has ones on the leading diagonal, zeros elsewhere."""
    initializer = pt.initializers.EyeLinearTransformInitializer()
    expected = torch.Tensor([
        [1, 0],
        [0, 1],
        [0, 0],
    ])
    assert torch.allclose(initializer.generate(3, 2), expected)
|
|
|
|
|
|
def test_eye_transform_init_wide():
    """A 2x3 eye transform has ones on the leading diagonal, zeros elsewhere."""
    initializer = pt.initializers.EyeLinearTransformInitializer()
    expected = torch.Tensor([
        [1, 0, 0],
        [0, 1, 0],
    ])
    assert torch.allclose(initializer.generate(2, 3), expected)
|
|
|
|
|
|
# Transforms
|
|
def test_linear_transform_default_eye_init():
    """Without an initializer, ``LinearTransform(2, 4)`` gets eye weights."""
    transform = pt.transforms.LinearTransform(2, 4)
    expected = torch.Tensor([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
    ])
    assert torch.allclose(transform.weights, expected)
|
|
|
|
|
|
def test_linear_transform_forward():
    """A default 4->2 transform keeps the first two input features."""
    transform = pt.transforms.LinearTransform(4, 2)

    expected_weights = torch.Tensor([
        [1, 0],
        [0, 1],
        [0, 0],
        [0, 0],
    ])
    assert torch.allclose(transform.weights, expected_weights)

    inputs = torch.Tensor([
        [1.1, 2.2, 3.3, 4.4],
        [1.1, 2.2, 3.3, 4.4],
        [5.5, 6.6, 7.7, 8.8],
    ])
    expected_outputs = torch.Tensor([
        [1.1, 2.2],
        [1.1, 2.2],
        [5.5, 6.6],
    ])
    assert torch.allclose(transform(inputs), expected_outputs)
|
|
|
|
|
|
def test_linear_transform_zeros_init():
    """An explicit zeros initializer gives an all-zero ``(2, 4)`` weight."""
    zeros_init = pt.initializers.ZerosLinearTransformInitializer()
    transform = pt.transforms.LinearTransform(
        in_dim=2,
        out_dim=4,
        initializer=zeros_init,
    )
    assert torch.allclose(transform.weights, torch.zeros(2, 4))
|
|
|
|
|
|
def test_linear_transform_out_dim_first():
    """``out_dim_first=True`` stores weights as ``(out_dim, in_dim)``."""
    ones_init = pt.initializers.OLTI(out_dim_first=True)
    transform = pt.transforms.LinearTransform(
        in_dim=2,
        out_dim=4,
        initializer=ones_init,
    )
    assert tuple(transform.weights.shape[:2]) == (4, 2)
|
|
|
|
|
|
# Components
|
|
def test_components_no_initializer():
    """Passing ``None`` as the initializer raises TypeError."""
    with pytest.raises(TypeError):
        pt.components.Components(3, None)
|
|
|
|
|
|
def test_components_no_num_components():
    """Omitting ``num_components`` raises TypeError."""
    ones_init = pt.initializers.OCI(2)
    with pytest.raises(TypeError):
        pt.components.Components(initializer=ones_init)
|
|
|
|
|
|
def test_components_none_num_components():
    """Passing ``None`` as ``num_components`` raises TypeError."""
    ones_init = pt.initializers.OCI(2)
    with pytest.raises(TypeError):
        pt.components.Components(None, initializer=ones_init)
|
|
|
|
|
|
def test_components_no_args():
    """Constructing ``Components`` without arguments raises TypeError."""
    with pytest.raises(TypeError):
        pt.components.Components()
|
|
|
|
|
|
def test_components_zeros_init():
    """Three 2-D components from a zeros initializer are all zero."""
    components = pt.components.Components(3, pt.initializers.ZCI(2))
    expected = torch.zeros(3, 2)
    assert torch.allclose(components.components, expected)
|
|
|
|
|
|
def test_labeled_components_dict_init():
    """A ``{label: count}`` dict gives that many components with that label."""
    labeled = pt.components.LabeledComponents({0: 3},
                                              pt.initializers.OCI(2))
    assert torch.allclose(labeled.components, torch.ones(3, 2))
    assert torch.allclose(labeled.labels, torch.zeros(3, dtype=torch.long))
|
|
|
|
|
|
def test_labeled_components_list_init():
    """A single-entry count list ``[3]`` gives three class-0 components."""
    labeled = pt.components.LabeledComponents([3], pt.initializers.OCI(2))
    assert torch.allclose(labeled.components, torch.ones(3, 2))
    assert torch.allclose(labeled.labels, torch.zeros(3, dtype=torch.long))
|
|
|
|
|
|
def test_labeled_components_tuple_init():
    """A multi-class distribution gives components labelled in class order.

    NOTE(review): despite the name, the distribution here is a dict, not a
    tuple.
    """
    distribution = {0: 1, 1: 2}
    labeled = pt.components.LabeledComponents(distribution,
                                              pt.initializers.OCI(2))
    assert torch.allclose(labeled.components, torch.ones(3, 2))
    assert torch.allclose(labeled.labels, torch.LongTensor([0, 1, 1]))
|
|
|
|
|
|
# Labels
|
|
def test_standalone_labels_dict_init():
    """``Labels({0: 3})`` holds three zero labels."""
    standalone = pt.components.Labels({0: 3})
    expected = torch.zeros(3, dtype=torch.long)
    assert torch.allclose(standalone.labels, expected)
|
|
|
|
|
|
def test_standalone_labels_list_init():
    """``Labels([3])`` holds three zero labels."""
    standalone = pt.components.Labels([3])
    expected = torch.zeros(3, dtype=torch.long)
    assert torch.allclose(standalone.labels, expected)
|
|
|
|
|
|
def test_standalone_labels_tuple_init():
    """A multi-class distribution yields labels in class order.

    NOTE(review): despite the name, the distribution here is a dict, not a
    tuple.
    """
    standalone = pt.components.Labels({0: 1, 1: 2})
    assert torch.allclose(standalone.labels, torch.LongTensor([0, 1, 1]))
|
|
|
|
|
|
# Losses
|
|
def test_glvq_loss_int_labels():
    """With integer labels, a correct prototype at distance 0 and a wrong
    one at distance 1 give a summed loss of -100 over 100 samples."""
    num_samples = 100
    # Column 0: prototype with label 0 (distance 1).
    # Column 1: prototype with label 1 (distance 0).
    distances = torch.stack(
        [torch.ones(num_samples), torch.zeros(num_samples)], dim=1)
    prototype_labels = torch.tensor([0, 1])
    targets = torch.ones(num_samples)
    batch_loss = pt.losses.glvq_loss(
        distances=distances,
        target_labels=targets,
        prototype_labels=prototype_labels,
    )
    assert torch.sum(batch_loss, dim=0) == -100
|
|
|
|
|
|
def test_glvq_loss_one_hot_labels():
    """Same scenario as the integer-label test, with one-hot labels."""
    num_samples = 100
    distances = torch.stack(
        [torch.ones(num_samples), torch.zeros(num_samples)], dim=1)
    prototype_labels = torch.tensor([[0, 1], [1, 0]])
    # Every sample carries the one-hot label of the second prototype.
    targets = torch.tensor([1, 0]).repeat(num_samples, 1)
    batch_loss = pt.losses.glvq_loss(
        distances=distances,
        target_labels=targets,
        prototype_labels=prototype_labels,
    )
    assert torch.sum(batch_loss, dim=0) == -100
|
|
|
|
|
|
def test_glvq_loss_one_hot_unequal():
    """One-hot labels with more prototypes than classes (two share a label)."""
    num_samples = 100
    columns = [
        torch.ones(num_samples),
        torch.zeros(num_samples),
        torch.zeros(num_samples),
    ]
    distances = torch.stack(columns, dim=1)
    prototype_labels = torch.tensor([[0, 1], [1, 0], [1, 0]])
    targets = torch.tensor([1, 0]).repeat(num_samples, 1)
    batch_loss = pt.losses.glvq_loss(
        distances=distances,
        target_labels=targets,
        prototype_labels=prototype_labels,
    )
    assert torch.sum(batch_loss, dim=0) == -100
|
|
|
|
|
|
# Activations
|
|
class TestActivations(unittest.TestCase):
    """Tests for ``pt.nn`` activation lookup and the built-in activations.

    ``np.testing.assert_array_almost_equal`` raises on mismatch and returns
    None, so its result is not captured or re-checked.
    """

    def setUp(self):
        self.flist = ["identity", "sigmoid_beta", "swish_beta"]
        self.x = torch.randn(1024, 1)

    def test_registry(self):
        self.assertIsNotNone(pt.nn.ACTIVATIONS)

    def test_funcname_deserialization(self):
        """Each registered name resolves to a callable."""
        for funcname in self.flist:
            f = pt.nn.get_activation(funcname)
            self.assertTrue(callable(f))

    def test_callable_deserialization(self):
        """Callables are accepted by ``get_activation`` and stay usable."""

        def dummy(x, **kwargs):
            return x

        for f in [dummy, lambda x: x]:
            f = pt.nn.get_activation(f)
            self.assertTrue(callable(f))
            self.assertEqual(1, f(1))

    def test_unknown_deserialization(self):
        """Unknown activation names raise NameError."""
        for funcname in ["blubb", "foobar"]:
            with self.assertRaises(NameError):
                _ = pt.nn.get_activation(funcname)

    def test_identity(self):
        actual = pt.nn.identity(self.x)
        desired = self.x
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_sigmoid_beta1(self):
        """With beta=1 the beta-sigmoid matches ``torch.sigmoid``."""
        actual = pt.nn.sigmoid_beta(self.x, beta=1.0)
        desired = torch.sigmoid(self.x)
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_swish_beta1(self):
        """With beta=1 the beta-swish matches ``x * sigmoid(x)``."""
        actual = pt.nn.swish_beta(self.x, beta=1.0)
        desired = self.x * torch.sigmoid(self.x)
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def tearDown(self):
        del self.x
|
|
|
|
|
|
# Competitions
|
|
class TestCompetitions(unittest.TestCase):
    """Tests for the winner-takes-all and k-NN competition layers.

    ``np.testing.assert_array_almost_equal`` raises on mismatch and returns
    None, so its result is not captured or re-checked. The no-op
    setUp/tearDown hooks are omitted; ``unittest.TestCase`` provides them.
    """

    def test_wtac(self):
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
        labels = torch.tensor([0, 1, 2, 3])
        competition_layer = pt.competitions.WTAC()
        actual = competition_layer(d, labels)
        desired = torch.tensor([2, 0])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_wtac_unequal_dist(self):
        d = torch.tensor([[2.0, 3.0, 4.0], [2.0, 3.0, 1.0]])
        labels = torch.tensor([0, 1, 1])
        competition_layer = pt.competitions.WTAC()
        actual = competition_layer(d, labels)
        desired = torch.tensor([0, 1])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_wtac_one_hot(self):
        d = torch.tensor([[1.99, 3.01], [3.0, 2.01]])
        labels = torch.tensor([[0, 1], [1, 0]])
        competition_layer = pt.competitions.WTAC()
        actual = competition_layer(d, labels)
        desired = torch.tensor([[0, 1], [1, 0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_knnc_k1(self):
        d = torch.tensor([[2.0, 3.0, 1.99, 3.01], [2.0, 3.0, 2.01, 3.0]])
        labels = torch.tensor([0, 1, 2, 3])
        competition_layer = pt.competitions.KNNC(k=1)
        actual = competition_layer(d, labels)
        desired = torch.tensor([2, 0])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)
|
|
|
|
|
|
# Pooling
|
|
class TestPooling(unittest.TestCase):
    """Tests for the stratified (per-class) pooling layers.

    Pooled columns follow the sorted set of labels present, e.g. labels
    ``[0, 0, 3, 2, 0]`` pool into columns for classes 0, 2, 3.
    ``np.testing.assert_array_almost_equal`` raises on mismatch and returns
    None, so its result is not captured or re-checked. The no-op
    setUp/tearDown hooks are omitted.
    """

    def test_stratified_min(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        pooling_layer = pt.pooling.StratifiedMinPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_stratified_min_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        pooling_layer = pt.pooling.StratifiedMinPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_stratified_min_trivial(self):
        d = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0, 1]])
        labels = torch.tensor([0, 1, 2])
        pooling_layer = pt.pooling.StratifiedMinPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 2.0, 3.0], [8.0, 0.0, 1.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_stratified_max(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        pooling_layer = pt.pooling.StratifiedMaxPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_stratified_max_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 2, 1, 0])
        labels = torch.nn.functional.one_hot(labels, num_classes=3)
        pooling_layer = pt.pooling.StratifiedMaxPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[9.0, 3.0, 2.0], [9.0, 1.0, 0.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_stratified_sum(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.LongTensor([0, 0, 1, 2])
        pooling_layer = pt.pooling.StratifiedSumPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_stratified_sum_one_hot(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0], [9.0, 8.0, 0, 1]])
        labels = torch.tensor([0, 0, 1, 2])
        labels = torch.eye(3)[labels]
        pooling_layer = pt.pooling.StratifiedSumPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[1.0, 2.0, 3.0], [17.0, 0.0, 1.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)

    def test_stratified_prod(self):
        d = torch.tensor([[1.0, 0.0, 2.0, 3.0, 9.0], [9.0, 8.0, 0, 1, 7.0]])
        labels = torch.tensor([0, 0, 3, 2, 0])
        pooling_layer = pt.pooling.StratifiedProdPooling()
        actual = pooling_layer(d, labels)
        desired = torch.tensor([[0.0, 3.0, 2.0], [504.0, 1.0, 0.0]])
        np.testing.assert_array_almost_equal(actual, desired, decimal=5)
|
|
|
|
|
|
# Distances
|
|
class TestDistances(unittest.TestCase):
    """Compare ProtoTorch distance functions against a brute-force reference.

    The reference is built element by element with
    ``torch.nn.functional.pairwise_distance``, independently of the
    vectorized implementations under test.
    ``np.testing.assert_array_almost_equal`` raises on mismatch and returns
    None, so its result is not captured or re-checked.
    """

    def setUp(self):
        self.nx, self.mx = 32, 2048
        self.ny, self.my = 8, 2048
        self.x = torch.randn(self.nx, self.mx)
        self.y = torch.randn(self.ny, self.my)

    def _pairwise_reference(self, p, squared=False):
        """Return the ``(nx, ny)`` matrix of p-norm distances between the
        rows of ``self.x`` and ``self.y``; square each entry if ``squared``.
        """
        desired = torch.empty(self.nx, self.ny)
        for i in range(self.nx):
            for j in range(self.ny):
                dist = torch.nn.functional.pairwise_distance(
                    self.x[i].reshape(1, -1),
                    self.y[j].reshape(1, -1),
                    p=p,
                    keepdim=False,
                )
                desired[i][j] = dist**2 if squared else dist
        return desired

    def test_manhattan(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=1)
        desired = self._pairwise_reference(p=1)
        np.testing.assert_array_almost_equal(actual, desired, decimal=2)

    def test_euclidean(self):
        actual = pt.distances.euclidean_distance(self.x, self.y)
        desired = self._pairwise_reference(p=2)
        np.testing.assert_array_almost_equal(actual, desired, decimal=3)

    def test_squared_euclidean(self):
        actual = pt.distances.squared_euclidean_distance(self.x, self.y)
        desired = self._pairwise_reference(p=2, squared=True)
        np.testing.assert_array_almost_equal(actual, desired, decimal=2)

    def test_lpnorm_p0(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=0)
        desired = self._pairwise_reference(p=0)
        np.testing.assert_array_almost_equal(actual, desired, decimal=4)

    def test_lpnorm_p2(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=2)
        desired = self._pairwise_reference(p=2)
        np.testing.assert_array_almost_equal(actual, desired, decimal=4)

    def test_lpnorm_p3(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=3)
        desired = self._pairwise_reference(p=3)
        np.testing.assert_array_almost_equal(actual, desired, decimal=4)

    def test_lpnorm_pinf(self):
        actual = pt.distances.lpnorm_distance(self.x, self.y, p=float("inf"))
        desired = self._pairwise_reference(p=float("inf"))
        np.testing.assert_array_almost_equal(actual, desired, decimal=4)

    def test_omega_identity(self):
        """With an identity Omega, the Omega distance is squared Euclidean."""
        omega = torch.eye(self.mx, self.my)
        actual = pt.distances.omega_distance(self.x, self.y, omega=omega)
        desired = self._pairwise_reference(p=2, squared=True)
        np.testing.assert_array_almost_equal(actual, desired, decimal=2)

    def test_lomega_identity(self):
        """With identity local Omegas, the result is squared Euclidean."""
        omega = torch.eye(self.mx, self.my)
        omegas = torch.stack([omega for _ in range(self.ny)], dim=0)
        actual = pt.distances.lomega_distance(self.x, self.y, omegas=omegas)
        desired = self._pairwise_reference(p=2, squared=True)
        np.testing.assert_array_almost_equal(actual, desired, decimal=2)

    def tearDown(self):
        del self.x, self.y
|