diff --git a/prototorch/modules/prototypes.py b/prototorch/modules/prototypes.py index 5a57e44..081a7b8 100644 --- a/prototorch/modules/prototypes.py +++ b/prototorch/modules/prototypes.py @@ -49,15 +49,25 @@ class Prototypes1D(torch.nn.Module): y_train = torch.as_tensor(y_train).type(dtype) nclasses = torch.unique(y_train).shape[0] - assert x_train.ndim == 2 + if x_train.ndim != 2: + raise ValueError('`data[0].ndim != 2`.') # Verify input dimension if `input_dim` is provided if 'input_dim' in kwargs: - assert kwargs.pop('input_dim') == x_train.shape[1] + input_dim = kwargs.pop('input_dim') + if input_dim != x_train.shape[1]: + raise ValueError(f'Provided `input_dim`={input_dim} does ' + 'not match data dimension ' + f'`data[0].shape[1]`={x_train.shape[1]}') # Verify the number of classes if `nclasses` is provided if 'nclasses' in kwargs: - assert nclasses == kwargs.pop('nclasses') + kwargs_nclasses = kwargs.pop('nclasses') + if kwargs_nclasses != nclasses: + raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does ' + 'not match data labels ' + '`torch.unique(data[1]).shape[0]`' + f'={nclasses}') super().__init__(**kwargs) diff --git a/tests/test_modules.py b/tests/test_modules.py index 97dc79d..bf04654 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -94,14 +94,14 @@ class TestPrototypes(unittest.TestCase): dtype=torch.int32) def test_prototypes1d_inputndim_with_data(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): _ = prototypes.Prototypes1D(input_dim=1, nclasses=1, prototypes_per_class=1, data=[[1.], [1]]) def test_prototypes1d_inputdim_with_data(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): _ = prototypes.Prototypes1D( input_dim=2, nclasses=1, @@ -110,7 +110,7 @@ class TestPrototypes(unittest.TestCase): data=[[[1.]], [1]]) def test_prototypes1d_nclasses_with_data(self): - with self.assertRaises(AssertionError): + with self.assertRaises(ValueError): _ = prototypes.Prototypes1D( input_dim=1, nclasses=1, @@ -194,3 +194,7 @@ class TestLosses(unittest.TestCase): def tearDown(self): pass + + +if __name__ == '__main__': + unittest.main()