Add small API changes and more test cases

This commit is contained in:
blackfly 2020-04-11 14:28:22 +02:00
parent da3b0cc262
commit 1ec7bd261b
3 changed files with 144 additions and 55 deletions

View File

@ -1 +1,12 @@
"""ProtoTorch package."""
__version__ = '0.1.1-dev0' __version__ = '0.1.1-dev0'
from prototorch import datasets, functions, modules, utils
__all__ = [
'datasets',
'functions',
'modules',
'utils',
]

View File

@ -1,23 +1,30 @@
"""ProtoTorch prototype modules.""" """ProtoTorch prototype modules."""
import warnings
import torch import torch
from prototorch.functions.initializers import get_initializer from prototorch.functions.initializers import get_initializer
class AddPrototypes1D(torch.nn.Module): class Prototypes1D(torch.nn.Module):
def __init__(self, def __init__(self,
prototypes_per_class=1, prototypes_per_class=1,
prototype_distribution=None, prototype_distribution=None,
prototype_initializer='ones', prototype_initializer='ones',
data=None, data=None,
dtype=torch.float32,
**kwargs): **kwargs):
# Accept PyTorch tensors, but convert to python lists before processing
if torch.is_tensor(prototype_distribution):
prototype_distribution = prototype_distribution.tolist()
if data is None: if data is None:
if 'input_dim' not in kwargs: if 'input_dim' not in kwargs:
raise NameError('`input_dim` required if ' raise NameError('`input_dim` required if '
'no `data` is provided.') 'no `data` is provided.')
if prototype_distribution is not None: if prototype_distribution:
nclasses = sum(prototype_distribution) nclasses = sum(prototype_distribution)
else: else:
if 'nclasses' not in kwargs: if 'nclasses' not in kwargs:
@ -26,30 +33,46 @@ class AddPrototypes1D(torch.nn.Module):
'provided.') 'provided.')
nclasses = kwargs.pop('nclasses') nclasses = kwargs.pop('nclasses')
input_dim = kwargs.pop('input_dim') input_dim = kwargs.pop('input_dim')
# input_shape = (input_dim, ) if prototype_initializer in [
'stratified_mean', 'stratified_random'
]:
warnings.warn(
f'`prototype_initializer`: `{prototype_initializer}` '
'requires `data`, but `data` is not provided. '
'Using randomly generated data instead.')
x_train = torch.rand(nclasses, input_dim) x_train = torch.rand(nclasses, input_dim)
y_train = torch.arange(nclasses) y_train = torch.arange(nclasses)
data = [x_train, y_train]
else:
x_train, y_train = data x_train, y_train = data
x_train = torch.as_tensor(x_train) x_train = torch.as_tensor(x_train).type(dtype)
y_train = torch.as_tensor(y_train) y_train = torch.as_tensor(y_train).type(dtype)
nclasses = torch.unique(y_train).shape[0]
assert x_train.ndim == 2
# Verify input dimension if `input_dim` is provided
if 'input_dim' in kwargs:
assert kwargs.pop('input_dim') == x_train.shape[1]
# Verify the number of classes if `nclasses` is provided
if 'nclasses' in kwargs:
assert nclasses == kwargs.pop('nclasses')
super().__init__(**kwargs) super().__init__(**kwargs)
self.prototypes_per_class = prototypes_per_class
with torch.no_grad():
if not prototype_distribution: if not prototype_distribution:
num_classes = torch.unique(y_train).shape[0] prototype_distribution = [prototypes_per_class] * nclasses
self.prototype_distribution = torch.tensor( with torch.no_grad():
[self.prototypes_per_class] * num_classes) self.prototype_distribution = torch.tensor(prototype_distribution)
else:
self.prototype_distribution = torch.tensor(
prototype_distribution)
self.prototype_initializer = get_initializer(prototype_initializer) self.prototype_initializer = get_initializer(prototype_initializer)
prototypes, prototype_labels = self.prototype_initializer( prototypes, prototype_labels = self.prototype_initializer(
x_train, x_train,
y_train, y_train,
prototype_distribution=self.prototype_distribution) prototype_distribution=self.prototype_distribution)
# Register module parameters
self.prototypes = torch.nn.Parameter(prototypes) self.prototypes = torch.nn.Parameter(prototypes)
self.prototype_labels = prototype_labels self.prototype_labels = prototype_labels

View File

@ -16,16 +16,16 @@ class TestPrototypes(unittest.TestCase):
self.y = torch.tensor([0, 0, 1, 1]) self.y = torch.tensor([0, 0, 1, 1])
self.gen = torch.manual_seed(42) self.gen = torch.manual_seed(42)
def test_addprototypes1d_init_without_input_dim(self): def test_prototypes1d_init_without_input_dim(self):
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = prototypes.AddPrototypes1D(nclasses=1) _ = prototypes.Prototypes1D(nclasses=1)
def test_addprototypes1d_init_without_nclasses(self): def test_prototypes1d_init_without_nclasses(self):
with self.assertRaises(NameError): with self.assertRaises(NameError):
_ = prototypes.AddPrototypes1D(input_dim=1) _ = prototypes.Prototypes1D(input_dim=1)
def test_addprototypes1d_init_without_pdist(self): def test_prototypes1d_init_without_pdist(self):
p1 = prototypes.AddPrototypes1D(input_dim=6, p1 = prototypes.Prototypes1D(input_dim=6,
nclasses=2, nclasses=2,
prototypes_per_class=4, prototypes_per_class=4,
prototype_initializer='ones') prototype_initializer='ones')
@ -37,9 +37,9 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_init_without_data(self): def test_prototypes1d_init_without_data(self):
pdist = [2, 2] pdist = [2, 2]
p1 = prototypes.AddPrototypes1D(input_dim=3, p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist, prototype_distribution=pdist,
prototype_initializer='zeros') prototype_initializer='zeros')
protos = p1.prototypes protos = p1.prototypes
@ -50,21 +50,76 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
# def test_addprototypes1d_init_torch_pdist(self): def test_prototypes1d_proto_init_without_data(self):
# pdist = torch.tensor([2, 2]) with self.assertWarns(Warning):
# p1 = prototypes.AddPrototypes1D(input_dim=3, _ = prototypes.Prototypes1D(
# prototype_distribution=pdist, input_dim=3,
# prototype_initializer='zeros') nclasses=2,
# protos = p1.prototypes prototypes_per_class=1,
# actual = protos.detach().numpy() prototype_initializer='stratified_mean',
# desired = torch.zeros(4, 3) data=None)
# mismatch = np.testing.assert_array_almost_equal(actual,
# desired,
# decimal=5)
# self.assertIsNone(mismatch)
def test_addprototypes1d_init_with_ppc(self): def test_prototypes1d_init_torch_pdist(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y], pdist = torch.tensor([2, 2])
p1 = prototypes.Prototypes1D(input_dim=3,
prototype_distribution=pdist,
prototype_initializer='zeros')
protos = p1.prototypes
actual = protos.detach().numpy()
desired = torch.zeros(4, 3)
mismatch = np.testing.assert_array_almost_equal(actual,
desired,
decimal=5)
self.assertIsNone(mismatch)
def test_prototypes1d_init_without_inputdim_with_data(self):
_ = prototypes.Prototypes1D(nclasses=1,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.]], [1]])
def test_prototypes1d_init_with_int_data(self):
_ = prototypes.Prototypes1D(nclasses=1,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1]], [1]])
def test_prototypes1d_init_with_int_dtype(self):
with self.assertRaises(RuntimeError):
_ = prototypes.Prototypes1D(
nclasses=1,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1]], [1]],
dtype=torch.int32)
def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(AssertionError):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=1,
prototypes_per_class=1,
data=[[1.], [1]])
def test_prototypes1d_inputdim_with_data(self):
with self.assertRaises(AssertionError):
_ = prototypes.Prototypes1D(
input_dim=2,
nclasses=1,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.]], [1]])
def test_prototypes1d_nclasses_with_data(self):
with self.assertRaises(AssertionError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=1,
prototypes_per_class=1,
prototype_initializer='stratified_mean',
data=[[[1.], [2.]], [1, 2]])
def test_prototypes1d_init_with_ppc(self):
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototypes_per_class=2, prototypes_per_class=2,
prototype_initializer='zeros') prototype_initializer='zeros')
protos = p1.prototypes protos = p1.prototypes
@ -75,8 +130,8 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_init_with_pdist(self): def test_prototypes1d_init_with_pdist(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y], p1 = prototypes.Prototypes1D(data=[self.x, self.y],
prototype_distribution=[6, 9], prototype_distribution=[6, 9],
prototype_initializer='zeros') prototype_initializer='zeros')
protos = p1.prototypes protos = p1.prototypes
@ -87,11 +142,11 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_func_initializer(self): def test_prototypes1d_func_initializer(self):
def my_initializer(*args, **kwargs): def my_initializer(*args, **kwargs):
return torch.full((2, 99), 99), torch.tensor([0, 1]) return torch.full((2, 99), 99), torch.tensor([0, 1])
p1 = prototypes.AddPrototypes1D(input_dim=99, p1 = prototypes.Prototypes1D(input_dim=99,
nclasses=2, nclasses=2,
prototypes_per_class=1, prototypes_per_class=1,
prototype_initializer=my_initializer) prototype_initializer=my_initializer)
@ -103,8 +158,8 @@ class TestPrototypes(unittest.TestCase):
decimal=5) decimal=5)
self.assertIsNone(mismatch) self.assertIsNone(mismatch)
def test_addprototypes1d_forward(self): def test_prototypes1d_forward(self):
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y]) p1 = prototypes.Prototypes1D(data=[self.x, self.y])
protos, _ = p1() protos, _ = p1()
actual = protos.detach().numpy() actual = protos.detach().numpy()
desired = torch.ones(2, 3) desired = torch.ones(2, 3)