Add small API changes and more test cases
This commit is contained in:
@@ -16,19 +16,19 @@ class TestPrototypes(unittest.TestCase):
|
||||
self.y = torch.tensor([0, 0, 1, 1])
|
||||
self.gen = torch.manual_seed(42)
|
||||
|
||||
def test_addprototypes1d_init_without_input_dim(self):
|
||||
def test_prototypes1d_init_without_input_dim(self):
|
||||
with self.assertRaises(NameError):
|
||||
_ = prototypes.AddPrototypes1D(nclasses=1)
|
||||
_ = prototypes.Prototypes1D(nclasses=1)
|
||||
|
||||
def test_addprototypes1d_init_without_nclasses(self):
|
||||
def test_prototypes1d_init_without_nclasses(self):
|
||||
with self.assertRaises(NameError):
|
||||
_ = prototypes.AddPrototypes1D(input_dim=1)
|
||||
_ = prototypes.Prototypes1D(input_dim=1)
|
||||
|
||||
def test_addprototypes1d_init_without_pdist(self):
|
||||
p1 = prototypes.AddPrototypes1D(input_dim=6,
|
||||
nclasses=2,
|
||||
prototypes_per_class=4,
|
||||
prototype_initializer='ones')
|
||||
def test_prototypes1d_init_without_pdist(self):
|
||||
p1 = prototypes.Prototypes1D(input_dim=6,
|
||||
nclasses=2,
|
||||
prototypes_per_class=4,
|
||||
prototype_initializer='ones')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.ones(8, 6)
|
||||
@@ -37,11 +37,11 @@ class TestPrototypes(unittest.TestCase):
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_init_without_data(self):
|
||||
def test_prototypes1d_init_without_data(self):
|
||||
pdist = [2, 2]
|
||||
p1 = prototypes.AddPrototypes1D(input_dim=3,
|
||||
prototype_distribution=pdist,
|
||||
prototype_initializer='zeros')
|
||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||
prototype_distribution=pdist,
|
||||
prototype_initializer='zeros')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.zeros(4, 3)
|
||||
@@ -50,23 +50,20 @@ class TestPrototypes(unittest.TestCase):
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
# def test_addprototypes1d_init_torch_pdist(self):
|
||||
# pdist = torch.tensor([2, 2])
|
||||
# p1 = prototypes.AddPrototypes1D(input_dim=3,
|
||||
# prototype_distribution=pdist,
|
||||
# prototype_initializer='zeros')
|
||||
# protos = p1.prototypes
|
||||
# actual = protos.detach().numpy()
|
||||
# desired = torch.zeros(4, 3)
|
||||
# mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
# desired,
|
||||
# decimal=5)
|
||||
# self.assertIsNone(mismatch)
|
||||
def test_prototypes1d_proto_init_without_data(self):
|
||||
with self.assertWarns(Warning):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=3,
|
||||
nclasses=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=None)
|
||||
|
||||
def test_addprototypes1d_init_with_ppc(self):
|
||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
|
||||
prototypes_per_class=2,
|
||||
prototype_initializer='zeros')
|
||||
def test_prototypes1d_init_torch_pdist(self):
|
||||
pdist = torch.tensor([2, 2])
|
||||
p1 = prototypes.Prototypes1D(input_dim=3,
|
||||
prototype_distribution=pdist,
|
||||
prototype_initializer='zeros')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.zeros(4, 3)
|
||||
@@ -75,10 +72,68 @@ class TestPrototypes(unittest.TestCase):
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_init_with_pdist(self):
|
||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y],
|
||||
prototype_distribution=[6, 9],
|
||||
prototype_initializer='zeros')
|
||||
def test_prototypes1d_init_without_inputdim_with_data(self):
|
||||
_ = prototypes.Prototypes1D(nclasses=1,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=[[[1.]], [1]])
|
||||
|
||||
def test_prototypes1d_init_with_int_data(self):
|
||||
_ = prototypes.Prototypes1D(nclasses=1,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=[[[1]], [1]])
|
||||
|
||||
def test_prototypes1d_init_with_int_dtype(self):
|
||||
with self.assertRaises(RuntimeError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
nclasses=1,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=[[[1]], [1]],
|
||||
dtype=torch.int32)
|
||||
|
||||
def test_prototypes1d_inputndim_with_data(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
_ = prototypes.Prototypes1D(input_dim=1,
|
||||
nclasses=1,
|
||||
prototypes_per_class=1,
|
||||
data=[[1.], [1]])
|
||||
|
||||
def test_prototypes1d_inputdim_with_data(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=2,
|
||||
nclasses=1,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=[[[1.]], [1]])
|
||||
|
||||
def test_prototypes1d_nclasses_with_data(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=1,
|
||||
nclasses=1,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer='stratified_mean',
|
||||
data=[[[1.], [2.]], [1, 2]])
|
||||
|
||||
def test_prototypes1d_init_with_ppc(self):
|
||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
||||
prototypes_per_class=2,
|
||||
prototype_initializer='zeros')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.zeros(4, 3)
|
||||
mismatch = np.testing.assert_array_almost_equal(actual,
|
||||
desired,
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_prototypes1d_init_with_pdist(self):
|
||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y],
|
||||
prototype_distribution=[6, 9],
|
||||
prototype_initializer='zeros')
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.zeros(15, 3)
|
||||
@@ -87,14 +142,14 @@ class TestPrototypes(unittest.TestCase):
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_func_initializer(self):
|
||||
def test_prototypes1d_func_initializer(self):
|
||||
def my_initializer(*args, **kwargs):
|
||||
return torch.full((2, 99), 99), torch.tensor([0, 1])
|
||||
|
||||
p1 = prototypes.AddPrototypes1D(input_dim=99,
|
||||
nclasses=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer=my_initializer)
|
||||
p1 = prototypes.Prototypes1D(input_dim=99,
|
||||
nclasses=2,
|
||||
prototypes_per_class=1,
|
||||
prototype_initializer=my_initializer)
|
||||
protos = p1.prototypes
|
||||
actual = protos.detach().numpy()
|
||||
desired = 99 * torch.ones(2, 99)
|
||||
@@ -103,8 +158,8 @@ class TestPrototypes(unittest.TestCase):
|
||||
decimal=5)
|
||||
self.assertIsNone(mismatch)
|
||||
|
||||
def test_addprototypes1d_forward(self):
|
||||
p1 = prototypes.AddPrototypes1D(data=[self.x, self.y])
|
||||
def test_prototypes1d_forward(self):
|
||||
p1 = prototypes.Prototypes1D(data=[self.x, self.y])
|
||||
protos, _ = p1()
|
||||
actual = protos.detach().numpy()
|
||||
desired = torch.ones(2, 3)
|
||||
|
Reference in New Issue
Block a user