Add small API changes and more test cases
This commit is contained in:
parent
da3b0cc262
commit
1ec7bd261b
@ -1 +1,12 @@
|
|||||||
|
"""ProtoTorch package."""
|
||||||
|
|
||||||
__version__ = '0.1.1-dev0'
|
__version__ = '0.1.1-dev0'
|
||||||
|
|
||||||
|
from prototorch import datasets, functions, modules, utils
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'datasets',
|
||||||
|
'functions',
|
||||||
|
'modules',
|
||||||
|
'utils',
|
||||||
|
]
|
||||||
|
@ -1,23 +1,30 @@
|
|||||||
"""ProtoTorch prototype modules."""
|
"""ProtoTorch prototype modules."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.initializers import get_initializer
|
from prototorch.functions.initializers import get_initializer
|
||||||
|
|
||||||
|
|
||||||
class AddPrototypes1D(torch.nn.Module):
|
class Prototypes1D(torch.nn.Module):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_distribution=None,
|
prototype_distribution=None,
|
||||||
prototype_initializer='ones',
|
prototype_initializer='ones',
|
||||||
data=None,
|
data=None,
|
||||||
|
dtype=torch.float32,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
|
||||||
|
# Accept PyTorch tensors, but convert to python lists before processing
|
||||||
|
if torch.is_tensor(prototype_distribution):
|
||||||
|
prototype_distribution = prototype_distribution.tolist()
|
||||||
|
|
||||||
if data is None:
|
if data is None:
|
||||||
if 'input_dim' not in kwargs:
|
if 'input_dim' not in kwargs:
|
||||||
raise NameError('`input_dim` required if '
|
raise NameError('`input_dim` required if '
|
||||||
'no `data` is provided.')
|
'no `data` is provided.')
|
||||||
if prototype_distribution is not None:
|
if prototype_distribution:
|
||||||
nclasses = sum(prototype_distribution)
|
nclasses = sum(prototype_distribution)
|
||||||
else:
|
else:
|
||||||
if 'nclasses' not in kwargs:
|
if 'nclasses' not in kwargs:
|
||||||
@ -26,30 +33,46 @@ class AddPrototypes1D(torch.nn.Module):
|
|||||||
'provided.')
|
'provided.')
|
||||||
nclasses = kwargs.pop('nclasses')
|
nclasses = kwargs.pop('nclasses')
|
||||||
input_dim = kwargs.pop('input_dim')
|
input_dim = kwargs.pop('input_dim')
|
||||||
# input_shape = (input_dim, )
|
if prototype_initializer in [
|
||||||
|
'stratified_mean', 'stratified_random'
|
||||||
|
]:
|
||||||
|
warnings.warn(
|
||||||
|
f'`prototype_initializer`: `{prototype_initializer}` '
|
||||||
|
'requires `data`, but `data` is not provided. '
|
||||||
|
'Using randomly generated data instead.')
|
||||||
x_train = torch.rand(nclasses, input_dim)
|
x_train = torch.rand(nclasses, input_dim)
|
||||||
y_train = torch.arange(nclasses)
|
y_train = torch.arange(nclasses)
|
||||||
|
data = [x_train, y_train]
|
||||||
|
|
||||||
else:
|
|
||||||
x_train, y_train = data
|
x_train, y_train = data
|
||||||
x_train = torch.as_tensor(x_train)
|
x_train = torch.as_tensor(x_train).type(dtype)
|
||||||
y_train = torch.as_tensor(y_train)
|
y_train = torch.as_tensor(y_train).type(dtype)
|
||||||
|
nclasses = torch.unique(y_train).shape[0]
|
||||||
|
|
||||||
|
assert x_train.ndim == 2
|
||||||
|
|
||||||
|
# Verify input dimension if `input_dim` is provided
|
||||||
|
if 'input_dim' in kwargs:
|
||||||
|
assert kwargs.pop('input_dim') == x_train.shape[1]
|
||||||
|
|
||||||
|
# Verify the number of classes if `nclasses` is provided
|
||||||
|
if 'nclasses' in kwargs:
|
||||||
|
assert nclasses == kwargs.pop('nclasses')
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.prototypes_per_class = prototypes_per_class
|
|
||||||
with torch.no_grad():
|
|
||||||
if not prototype_distribution:
|
if not prototype_distribution:
|
||||||
num_classes = torch.unique(y_train).shape[0]
|
prototype_distribution = [prototypes_per_class] * nclasses
|
||||||
self.prototype_distribution = torch.tensor(
|
with torch.no_grad():
|
||||||
[self.prototypes_per_class] * num_classes)
|
self.prototype_distribution = torch.tensor(prototype_distribution)
|
||||||
else:
|
|
||||||
self.prototype_distribution = torch.tensor(
|
|
||||||
prototype_distribution)
|
|
||||||
self.prototype_initializer = get_initializer(prototype_initializer)
|
self.prototype_initializer = get_initializer(prototype_initializer)
|
||||||
prototypes, prototype_labels = self.prototype_initializer(
|
prototypes, prototype_labels = self.prototype_initializer(
|
||||||
x_train,
|
x_train,
|
||||||
y_train,
|
y_train,
|
||||||
prototype_distribution=self.prototype_distribution)
|
prototype_distribution=self.prototype_distribution)
|
||||||
|
|
||||||
|
# Register module parameters
|
||||||
self.prototypes = torch.nn.Parameter(prototypes)
|
self.prototypes = torch.nn.Parameter(prototypes)
|
||||||
self.prototype_labels = prototype_labels
|
self.prototype_labels = prototype_labels
|
||||||
|
|
||||||
|
@ -16,16 +16,16 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
self.y = torch.tensor([0, 0, 1, 1])
|
self.y = torch.tensor([0, 0, 1, 1])
|
||||||
self.gen = torch.manual_seed(42)
|
self.gen = torch.manual_seed(42)
|
||||||
|
|
||||||
def test_addprototypes1d_init_without_input_dim(self):
|
def test_prototypes1d_init_without_input_dim(self):
|
||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
_ = prototypes.AddPrototypes1D(nclasses=1)
|
_ = prototypes.Prototypes1D(nclasses=1)
|
||||||
|
|
||||||
def test_addprototypes1d_init_without_nclasses(self):
|
def test_prototypes1d_init_without_nclasses(self):
|
||||||
with self.assertRaises(NameError):
|
with self.assertRaises(NameError):
|
||||||
_ = prototypes.AddPrototypes1D(input_dim=1)
|
_ = prototypes.Prototypes1D(input_dim=1)
|
||||||
|
|
||||||
def test_addprototypes1d_init_without_pdist(self):
|
def test_prototypes1d_init_without_pdist(self):
|
||||||
p1 = prototypes.AddPrototypes1D(input_dim=6,
|
p1 = prototypes.Prototypes1D(input_dim=6,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=4,
|
prototypes_per_class=4,
|
||||||
prototype_initializer='ones')
|
prototype_initializer='ones')
|
||||||
@ -37,9 +37,9 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_addprototypes1d_init_without_data(self):
|
def test_prototypes1d_init_without_data(self):
|
||||||
pdist = [2, 2]
|
pdist = [2, 2]
|
||||||
p1 = prototypes.AddPrototypes1D(input_dim=3,
|
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||||
prototype_distribution=pdist,
|
prototype_distribution=pdist,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer='zeros')
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
@ -50,21 +50,76 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
# def test_addprototypes1d_init_torch_pdist(self):
|
def test_prototypes1d_proto_init_without_data(self):
|
||||||
# pdist = torch.tensor([2, 2])
|
with self.assertWarns(Warning):
|
||||||
# p1 = prototypes.AddPrototypes1D(input_dim=3,
|
_ = prototypes.Prototypes1D(
|
||||||
# prototype_distribution=pdist,
|
input_dim=3,
|
||||||
# prototype_initializer='zeros')
|
nclasses=2,
|
||||||
# protos = p1.prototypes
|
prototypes_per_class=1,
|
||||||
# actual = protos.detach().numpy()
|
prototype_initializer='stratified_mean',
|
||||||
# desired = torch.zeros(4, 3)
|
data=None)
|
||||||
# mismatch = np.testing.assert_array_almost_equal(actual,
|
|
||||||
# desired,
|
|
||||||
# decimal=5)
|
|
||||||
# self.assertIsNone(mismatch)
|
|
||||||
|
|
||||||
def test_addprototypes1d_init_with_ppc(self):
|
def test_prototypes1d_init_torch_pdist(self):
|
||||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
|
pdist = torch.tensor([2, 2])
|
||||||
|
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||||
|
prototype_distribution=pdist,
|
||||||
|
prototype_initializer='zeros')
|
||||||
|
protos = p1.prototypes
|
||||||
|
actual = protos.detach().numpy()
|
||||||
|
desired = torch.zeros(4, 3)
|
||||||
|
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||||
|
desired,
|
||||||
|
decimal=5)
|
||||||
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
|
def test_prototypes1d_init_without_inputdim_with_data(self):
|
||||||
|
_ = prototypes.Prototypes1D(nclasses=1,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer='stratified_mean',
|
||||||
|
data=[[[1.]], [1]])
|
||||||
|
|
||||||
|
def test_prototypes1d_init_with_int_data(self):
|
||||||
|
_ = prototypes.Prototypes1D(nclasses=1,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer='stratified_mean',
|
||||||
|
data=[[[1]], [1]])
|
||||||
|
|
||||||
|
def test_prototypes1d_init_with_int_dtype(self):
|
||||||
|
with self.assertRaises(RuntimeError):
|
||||||
|
_ = prototypes.Prototypes1D(
|
||||||
|
nclasses=1,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer='stratified_mean',
|
||||||
|
data=[[[1]], [1]],
|
||||||
|
dtype=torch.int32)
|
||||||
|
|
||||||
|
def test_prototypes1d_inputndim_with_data(self):
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
_ = prototypes.Prototypes1D(input_dim=1,
|
||||||
|
nclasses=1,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
data=[[1.], [1]])
|
||||||
|
|
||||||
|
def test_prototypes1d_inputdim_with_data(self):
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
_ = prototypes.Prototypes1D(
|
||||||
|
input_dim=2,
|
||||||
|
nclasses=1,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer='stratified_mean',
|
||||||
|
data=[[[1.]], [1]])
|
||||||
|
|
||||||
|
def test_prototypes1d_nclasses_with_data(self):
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
_ = prototypes.Prototypes1D(
|
||||||
|
input_dim=1,
|
||||||
|
nclasses=1,
|
||||||
|
prototypes_per_class=1,
|
||||||
|
prototype_initializer='stratified_mean',
|
||||||
|
data=[[[1.], [2.]], [1, 2]])
|
||||||
|
|
||||||
|
def test_prototypes1d_init_with_ppc(self):
|
||||||
|
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
||||||
prototypes_per_class=2,
|
prototypes_per_class=2,
|
||||||
prototype_initializer='zeros')
|
prototype_initializer='zeros')
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
@ -75,8 +130,8 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_addprototypes1d_init_with_pdist(self):
|
def test_prototypes1d_init_with_pdist(self):
|
||||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
|
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
||||||
prototype_distribution=[6, 9],
|
prototype_distribution=[6, 9],
|
||||||
prototype_initializer='zeros')
|
prototype_initializer='zeros')
|
||||||
protos = p1.prototypes
|
protos = p1.prototypes
|
||||||
@ -87,11 +142,11 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_addprototypes1d_func_initializer(self):
|
def test_prototypes1d_func_initializer(self):
|
||||||
def my_initializer(*args, **kwargs):
|
def my_initializer(*args, **kwargs):
|
||||||
return torch.full((2, 99), 99), torch.tensor([0, 1])
|
return torch.full((2, 99), 99), torch.tensor([0, 1])
|
||||||
|
|
||||||
p1 = prototypes.AddPrototypes1D(input_dim=99,
|
p1 = prototypes.Prototypes1D(input_dim=99,
|
||||||
nclasses=2,
|
nclasses=2,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
prototype_initializer=my_initializer)
|
prototype_initializer=my_initializer)
|
||||||
@ -103,8 +158,8 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
decimal=5)
|
decimal=5)
|
||||||
self.assertIsNone(mismatch)
|
self.assertIsNone(mismatch)
|
||||||
|
|
||||||
def test_addprototypes1d_forward(self):
|
def test_prototypes1d_forward(self):
|
||||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y])
|
p1 = prototypes.Prototypes1D(data=[self.x, self.y])
|
||||||
protos, _ = p1()
|
protos, _ = p1()
|
||||||
actual = protos.detach().numpy()
|
actual = protos.detach().numpy()
|
||||||
desired = torch.ones(2, 3)
|
desired = torch.ones(2, 3)
|
||||||
|
Loading…
Reference in New Issue
Block a user