Remove Prototypes1D and its tests
This commit is contained in:
parent
40ef3aeda2
commit
87334c11e6
@ -1,7 +1 @@
|
|||||||
"""ProtoTorch modules."""
|
"""ProtoTorch modules."""
|
||||||
|
|
||||||
from .prototypes import Prototypes1D
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Prototypes1D",
|
|
||||||
]
|
|
@ -1,137 +0,0 @@
|
|||||||
"""ProtoTorch prototype modules."""
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from prototorch.functions.initializers import get_initializer
|
|
||||||
|
|
||||||
|
|
||||||
class _Prototypes(torch.nn.Module):
|
|
||||||
"""Abstract prototypes class."""
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def _validate_prototype_distribution(self):
|
|
||||||
if 0 in self.prototype_distribution:
|
|
||||||
warnings.warn("Are you sure about the `0` in "
|
|
||||||
"`prototype_distribution`?")
|
|
||||||
|
|
||||||
def extra_repr(self):
|
|
||||||
return f"prototypes.shape: {tuple(self.prototypes.shape)}"
|
|
||||||
|
|
||||||
def forward(self):
|
|
||||||
return self.prototypes, self.prototype_labels
|
|
||||||
|
|
||||||
|
|
||||||
class Prototypes1D(_Prototypes):
|
|
||||||
"""Create a learnable set of one-dimensional prototypes.
|
|
||||||
|
|
||||||
TODO Complete this doc-string.
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="ones",
|
|
||||||
prototype_distribution=None,
|
|
||||||
data=None,
|
|
||||||
dtype=torch.float32,
|
|
||||||
one_hot_labels=False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
warnings.warn(
|
|
||||||
PendingDeprecationWarning(
|
|
||||||
"Prototypes1D will be replaced in future versions."))
|
|
||||||
|
|
||||||
# Convert tensors to python lists before processing
|
|
||||||
if prototype_distribution is not None:
|
|
||||||
if not isinstance(prototype_distribution, list):
|
|
||||||
prototype_distribution = prototype_distribution.tolist()
|
|
||||||
|
|
||||||
if data is None:
|
|
||||||
if "input_dim" not in kwargs:
|
|
||||||
raise NameError("`input_dim` required if "
|
|
||||||
"no `data` is provided.")
|
|
||||||
if prototype_distribution:
|
|
||||||
kwargs_num_classes = sum(prototype_distribution)
|
|
||||||
else:
|
|
||||||
if "num_classes" not in kwargs:
|
|
||||||
raise NameError("`prototype_distribution` required if "
|
|
||||||
"both `data` and `num_classes` are not "
|
|
||||||
"provided.")
|
|
||||||
kwargs_num_classes = kwargs.pop("num_classes")
|
|
||||||
input_dim = kwargs.pop("input_dim")
|
|
||||||
if prototype_initializer in [
|
|
||||||
"stratified_mean", "stratified_random"
|
|
||||||
]:
|
|
||||||
warnings.warn(
|
|
||||||
f"`prototype_initializer`: `{prototype_initializer}` "
|
|
||||||
"requires `data`, but `data` is not provided. "
|
|
||||||
"Using randomly generated data instead.")
|
|
||||||
x_train = torch.rand(kwargs_num_classes, input_dim)
|
|
||||||
y_train = torch.arange(kwargs_num_classes)
|
|
||||||
if one_hot_labels:
|
|
||||||
y_train = torch.eye(kwargs_num_classes)[y_train]
|
|
||||||
data = [x_train, y_train]
|
|
||||||
|
|
||||||
x_train, y_train = data
|
|
||||||
x_train = torch.as_tensor(x_train).type(dtype)
|
|
||||||
y_train = torch.as_tensor(y_train).type(torch.int)
|
|
||||||
num_classes = torch.unique(y_train, dim=-1).shape[-1]
|
|
||||||
|
|
||||||
if num_classes == 1:
|
|
||||||
warnings.warn("Are you sure about having one class only?")
|
|
||||||
|
|
||||||
if x_train.ndim != 2:
|
|
||||||
raise ValueError("`data[0].ndim != 2`.")
|
|
||||||
|
|
||||||
if y_train.ndim == 2:
|
|
||||||
if y_train.shape[1] == 1 and one_hot_labels:
|
|
||||||
raise ValueError("`one_hot_labels` is set to `True` "
|
|
||||||
"but target labels are not one-hot-encoded.")
|
|
||||||
if y_train.shape[1] != 1 and not one_hot_labels:
|
|
||||||
raise ValueError("`one_hot_labels` is set to `False` "
|
|
||||||
"but target labels in `data` "
|
|
||||||
"are one-hot-encoded.")
|
|
||||||
if y_train.ndim == 1 and one_hot_labels:
|
|
||||||
raise ValueError("`one_hot_labels` is set to `True` "
|
|
||||||
"but target labels are not one-hot-encoded.")
|
|
||||||
|
|
||||||
# Verify input dimension if `input_dim` is provided
|
|
||||||
if "input_dim" in kwargs:
|
|
||||||
input_dim = kwargs.pop("input_dim")
|
|
||||||
if input_dim != x_train.shape[1]:
|
|
||||||
raise ValueError(f"Provided `input_dim`={input_dim} does "
|
|
||||||
"not match data dimension "
|
|
||||||
f"`data[0].shape[1]`={x_train.shape[1]}")
|
|
||||||
|
|
||||||
# Verify the number of classes if `num_classes` is provided
|
|
||||||
if "num_classes" in kwargs:
|
|
||||||
kwargs_num_classes = kwargs.pop("num_classes")
|
|
||||||
if kwargs_num_classes != num_classes:
|
|
||||||
raise ValueError(
|
|
||||||
f"Provided `num_classes={kwargs_num_classes}` does "
|
|
||||||
"not match data labels "
|
|
||||||
"`torch.unique(data[1]).shape[0]`"
|
|
||||||
f"={num_classes}")
|
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
if not prototype_distribution:
|
|
||||||
prototype_distribution = [prototypes_per_class] * num_classes
|
|
||||||
with torch.no_grad():
|
|
||||||
self.prototype_distribution = torch.tensor(prototype_distribution)
|
|
||||||
|
|
||||||
self._validate_prototype_distribution()
|
|
||||||
|
|
||||||
self.prototype_initializer = get_initializer(prototype_initializer)
|
|
||||||
prototypes, prototype_labels = self.prototype_initializer(
|
|
||||||
x_train,
|
|
||||||
y_train,
|
|
||||||
prototype_distribution=self.prototype_distribution,
|
|
||||||
one_hot=one_hot_labels,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register module parameters
|
|
||||||
self.prototypes = torch.nn.Parameter(prototypes)
|
|
||||||
self.prototype_labels = torch.nn.Parameter(
|
|
||||||
prototype_labels.type(dtype)).requires_grad_(False)
|
|
@ -1,297 +0,0 @@
|
|||||||
"""ProtoTorch modules test suite."""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from prototorch.modules import losses, prototypes
|
|
||||||
|
|
||||||
|
|
||||||
class TestPrototypes(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
self.x = torch.tensor(
|
|
||||||
[[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]],
|
|
||||||
dtype=torch.float32)
|
|
||||||
self.y = torch.tensor([0, 0, 1, 1])
|
|
||||||
self.gen = torch.manual_seed(42)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_input_dim(self):
|
|
||||||
with self.assertRaises(NameError):
|
|
||||||
_ = prototypes.Prototypes1D(num_classes=2)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_num_classes(self):
|
|
||||||
with self.assertRaises(NameError):
|
|
||||||
_ = prototypes.Prototypes1D(input_dim=1)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_num_classes_1(self):
|
|
||||||
with self.assertWarns(UserWarning):
|
|
||||||
_ = prototypes.Prototypes1D(num_classes=1, input_dim=1)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_pdist(self):
|
|
||||||
p1 = prototypes.Prototypes1D(
|
|
||||||
input_dim=6,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=4,
|
|
||||||
prototype_initializer="ones",
|
|
||||||
)
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.ones(8, 6)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_data(self):
|
|
||||||
pdist = [2, 2]
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
|
||||||
prototype_distribution=pdist,
|
|
||||||
prototype_initializer="zeros")
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(4, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_proto_init_without_data(self):
|
|
||||||
with self.assertWarns(UserWarning):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=3,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_torch_pdist(self):
|
|
||||||
pdist = torch.tensor([2, 2])
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
|
||||||
prototype_distribution=pdist,
|
|
||||||
prototype_initializer="zeros")
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(4, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_without_inputdim_with_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=[[[1.0], [0.0]], [1, 0]],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=[[[1], [0]], [1, 0]],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_without_data(self):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=None,
|
|
||||||
one_hot_labels=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_false(self):
|
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `False`
|
|
||||||
but the provided `data` has one-hot encoded labels.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=([[0.0], [1.0]], [[0, 1], [1, 0]]),
|
|
||||||
one_hot_labels=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_1d_y_data_one_hot_labels_true(self):
|
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
|
||||||
but the provided `data` does not contain one-hot encoded labels.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=([[0.0], [1.0]], [0, 1]),
|
|
||||||
one_hot_labels=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_one_hot_labels_true(self):
|
|
||||||
"""Test if ValueError is raised when `one_hot_labels` is set to `True`
|
|
||||||
but the provided `data` contains 2D targets but
|
|
||||||
does not contain one-hot encoded labels.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=([[0.0], [1.0]], [[0], [1]]),
|
|
||||||
one_hot_labels=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_int_dtype(self):
|
|
||||||
with self.assertRaises(RuntimeError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=[[[1], [0]], [1, 0]],
|
|
||||||
dtype=torch.int32,
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_inputndim_with_data(self):
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
|
||||||
num_classes=1,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
data=[[1.0], [1]])
|
|
||||||
|
|
||||||
def test_prototypes1d_inputdim_with_data(self):
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=2,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=[[[1.0], [0.0]], [1, 0]],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_num_classes_with_data(self):
|
|
||||||
"""Test ValueError raise if provided `num_classes` is not the same
|
|
||||||
as the one computed from the provided `data`.
|
|
||||||
"""
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
_ = prototypes.Prototypes1D(
|
|
||||||
input_dim=1,
|
|
||||||
num_classes=1,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer="stratified_mean",
|
|
||||||
data=[[[1.0], [2.0]], [1, 2]],
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_ppc(self):
|
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
|
||||||
prototypes_per_class=2,
|
|
||||||
prototype_initializer="zeros")
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(4, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_init_with_pdist(self):
|
|
||||||
p1 = prototypes.Prototypes1D(
|
|
||||||
data=[self.x, self.y],
|
|
||||||
prototype_distribution=[6, 9],
|
|
||||||
prototype_initializer="zeros",
|
|
||||||
)
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.zeros(15, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_func_initializer(self):
|
|
||||||
def my_initializer(*args, **kwargs):
|
|
||||||
return torch.full((2, 99), 99.0), torch.tensor([0, 1])
|
|
||||||
|
|
||||||
p1 = prototypes.Prototypes1D(
|
|
||||||
input_dim=99,
|
|
||||||
num_classes=2,
|
|
||||||
prototypes_per_class=1,
|
|
||||||
prototype_initializer=my_initializer,
|
|
||||||
)
|
|
||||||
protos = p1.prototypes
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = 99 * torch.ones(2, 99)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_forward(self):
|
|
||||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y])
|
|
||||||
protos, _ = p1()
|
|
||||||
actual = protos.detach().numpy()
|
|
||||||
desired = torch.ones(2, 3)
|
|
||||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
desired,
|
|
||||||
decimal=5)
|
|
||||||
self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_prototypes1d_dist_validate(self):
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
|
||||||
with self.assertWarns(UserWarning):
|
|
||||||
_ = p1._validate_prototype_distribution()
|
|
||||||
|
|
||||||
def test_prototypes1d_validate_extra_repr_not_empty(self):
|
|
||||||
p1 = prototypes.Prototypes1D(input_dim=0, prototype_distribution=[0])
|
|
||||||
rep = p1.extra_repr()
|
|
||||||
self.assertNotEqual(rep, "")
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
del self.x, self.y, self.gen
|
|
||||||
_ = torch.seed()
|
|
||||||
|
|
||||||
|
|
||||||
class TestLosses(unittest.TestCase):
|
|
||||||
def setUp(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_glvqloss_init(self):
|
|
||||||
_ = losses.GLVQLoss(0, "swish_beta", beta=20)
|
|
||||||
|
|
||||||
def test_glvqloss_forward_1ppc(self):
|
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
|
||||||
squashing="sigmoid_beta",
|
|
||||||
beta=100)
|
|
||||||
d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1)
|
|
||||||
labels = torch.tensor([0, 1])
|
|
||||||
targets = torch.ones(100)
|
|
||||||
outputs = [d, labels]
|
|
||||||
loss = criterion(outputs, targets)
|
|
||||||
loss_value = loss.item()
|
|
||||||
self.assertAlmostEqual(loss_value, 0.0)
|
|
||||||
|
|
||||||
def test_glvqloss_forward_2ppc(self):
|
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
|
||||||
squashing="sigmoid_beta",
|
|
||||||
beta=100)
|
|
||||||
d = torch.stack([
|
|
||||||
torch.ones(100),
|
|
||||||
torch.ones(100),
|
|
||||||
torch.zeros(100),
|
|
||||||
torch.ones(100)
|
|
||||||
],
|
|
||||||
dim=1)
|
|
||||||
labels = torch.tensor([0, 0, 1, 1])
|
|
||||||
targets = torch.ones(100)
|
|
||||||
outputs = [d, labels]
|
|
||||||
loss = criterion(outputs, targets)
|
|
||||||
loss_value = loss.item()
|
|
||||||
self.assertAlmostEqual(loss_value, 0.0)
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
pass
|
|
Loading…
Reference in New Issue
Block a user