"""ProtoTorch modules test suite.""" import unittest import numpy as np import torch from prototorch.modules import prototypes, losses class TestPrototypes(unittest.TestCase): def setUp(self): self.x = torch.tensor( [[0, -1, -2], [10, 11, 12], [0, 0, 0], [2, 2, 2]], dtype=torch.float32) self.y = torch.tensor([0, 0, 1, 1]) self.gen = torch.manual_seed(42) def test_prototypes1d_init_without_input_dim(self): with self.assertRaises(NameError): _ = prototypes.Prototypes1D(nclasses=1) def test_prototypes1d_init_without_nclasses(self): with self.assertRaises(NameError): _ = prototypes.Prototypes1D(input_dim=1) def test_prototypes1d_init_without_pdist(self): p1 = prototypes.Prototypes1D(input_dim=6, nclasses=2, prototypes_per_class=4, prototype_initializer='ones') protos = p1.prototypes actual = protos.detach().numpy() desired = torch.ones(8, 6) mismatch = np.testing.assert_array_almost_equal(actual, desired, decimal=5) self.assertIsNone(mismatch) def test_prototypes1d_init_without_data(self): pdist = [2, 2] p1 = prototypes.Prototypes1D(input_dim=3, prototype_distribution=pdist, prototype_initializer='zeros') protos = p1.prototypes actual = protos.detach().numpy() desired = torch.zeros(4, 3) mismatch = np.testing.assert_array_almost_equal(actual, desired, decimal=5) self.assertIsNone(mismatch) def test_prototypes1d_proto_init_without_data(self): with self.assertWarns(Warning): _ = prototypes.Prototypes1D( input_dim=3, nclasses=2, prototypes_per_class=1, prototype_initializer='stratified_mean', data=None) def test_prototypes1d_init_torch_pdist(self): pdist = torch.tensor([2, 2]) p1 = prototypes.Prototypes1D(input_dim=3, prototype_distribution=pdist, prototype_initializer='zeros') protos = p1.prototypes actual = protos.detach().numpy() desired = torch.zeros(4, 3) mismatch = np.testing.assert_array_almost_equal(actual, desired, decimal=5) self.assertIsNone(mismatch) def test_prototypes1d_init_without_inputdim_with_data(self): _ = prototypes.Prototypes1D(nclasses=1, prototypes_per_class=1, prototype_initializer='stratified_mean', data=[[[1.]], [1]]) def test_prototypes1d_init_with_int_data(self): _ = prototypes.Prototypes1D(nclasses=1, prototypes_per_class=1, prototype_initializer='stratified_mean', data=[[[1]], [1]]) def test_prototypes1d_init_with_int_dtype(self): with self.assertRaises(RuntimeError): _ = prototypes.Prototypes1D( nclasses=1, prototypes_per_class=1, prototype_initializer='stratified_mean', data=[[[1]], [1]], dtype=torch.int32) def test_prototypes1d_inputndim_with_data(self): with self.assertRaises(ValueError): _ = prototypes.Prototypes1D(input_dim=1, nclasses=1, prototypes_per_class=1, data=[[1.], [1]]) def test_prototypes1d_inputdim_with_data(self): with self.assertRaises(ValueError): _ = prototypes.Prototypes1D( input_dim=2, nclasses=1, prototypes_per_class=1, prototype_initializer='stratified_mean', data=[[[1.]], [1]]) def test_prototypes1d_nclasses_with_data(self): with self.assertRaises(ValueError): _ = prototypes.Prototypes1D( input_dim=1, nclasses=1, prototypes_per_class=1, prototype_initializer='stratified_mean', data=[[[1.], [2.]], [1, 2]]) def test_prototypes1d_init_with_ppc(self): p1 = prototypes.Prototypes1D(data=[self.x, self.y], prototypes_per_class=2, prototype_initializer='zeros') protos = p1.prototypes actual = protos.detach().numpy() desired = torch.zeros(4, 3) mismatch = np.testing.assert_array_almost_equal(actual, desired, decimal=5) self.assertIsNone(mismatch) def test_prototypes1d_init_with_pdist(self): p1 = prototypes.Prototypes1D(data=[self.x, self.y], prototype_distribution=[6, 9], prototype_initializer='zeros') protos = p1.prototypes actual = protos.detach().numpy() desired = torch.zeros(15, 3) mismatch = np.testing.assert_array_almost_equal(actual, desired, decimal=5) self.assertIsNone(mismatch) def test_prototypes1d_func_initializer(self): def my_initializer(*args, **kwargs): return torch.full((2, 99), 99), torch.tensor([0, 1]) p1 = prototypes.Prototypes1D(input_dim=99, nclasses=2, prototypes_per_class=1, prototype_initializer=my_initializer) protos = p1.prototypes actual = protos.detach().numpy() desired = 99 * torch.ones(2, 99) mismatch = np.testing.assert_array_almost_equal(actual, desired, decimal=5) self.assertIsNone(mismatch) def test_prototypes1d_forward(self): p1 = prototypes.Prototypes1D(data=[self.x, self.y]) protos, _ = p1() actual = protos.detach().numpy() desired = torch.ones(2, 3) mismatch = np.testing.assert_array_almost_equal(actual, desired, decimal=5) self.assertIsNone(mismatch) def tearDown(self): del self.x, self.y, self.gen _ = torch.seed() class TestLosses(unittest.TestCase): def setUp(self): pass def test_glvqloss_init(self): _ = losses.GLVQLoss(0, 'swish_beta', beta=20) def test_glvqloss_forward(self): criterion = losses.GLVQLoss(margin=0, squashing='sigmoid_beta', beta=100) d = torch.stack([torch.ones(100), torch.zeros(100)], dim=1) labels = torch.tensor([0, 1]) targets = torch.ones(100) outputs = [d, labels] loss = criterion(outputs, targets) loss_value = loss.item() self.assertAlmostEqual(loss_value, 0.0) def tearDown(self): pass if __name__ == '__main__': unittest.main()