Remove assert statements following codacy security recommendation

"Use of assert detected. The enclosed code will be removed when compiling to
optimised byte code."
This commit is contained in:
blackfly 2020-04-11 15:45:29 +02:00
parent 955661af95
commit 8f3a43f62a
2 changed files with 20 additions and 6 deletions

View File

@ -49,15 +49,25 @@ class Prototypes1D(torch.nn.Module):
y_train = torch.as_tensor(y_train).type(dtype)
nclasses = torch.unique(y_train).shape[0]
assert x_train.ndim == 2
if x_train.ndim != 2:
raise ValueError('`data[0].ndim != 2`.')
# Verify input dimension if `input_dim` is provided
if 'input_dim' in kwargs:
assert kwargs.pop('input_dim') == x_train.shape[1]
input_dim = kwargs.pop('input_dim')
if input_dim != x_train.shape[1]:
raise ValueError(f'Provided `input_dim`={input_dim} does '
'not match data dimension '
f'`data[0].shape[1]`={x_train.shape[1]}')
# Verify the number of classes if `nclasses` is provided
if 'nclasses' in kwargs:
assert nclasses == kwargs.pop('nclasses')
kwargs_nclasses = kwargs.pop('nclasses')
if kwargs_nclasses != nclasses:
raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
'not match data labels '
'`torch.unique(data[1]).shape[0]`'
f'={nclasses}')
super().__init__(**kwargs)

View File

@ -94,14 +94,14 @@ class TestPrototypes(unittest.TestCase):
dtype=torch.int32)
def test_prototypes1d_inputndim_with_data(self):
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(input_dim=1,
nclasses=1,
prototypes_per_class=1,
data=[[1.], [1]])
def test_prototypes1d_inputdim_with_data(self):
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=2,
nclasses=1,
@ -110,7 +110,7 @@ class TestPrototypes(unittest.TestCase):
data=[[[1.]], [1]])
def test_prototypes1d_nclasses_with_data(self):
with self.assertRaises(AssertionError):
with self.assertRaises(ValueError):
_ = prototypes.Prototypes1D(
input_dim=1,
nclasses=1,
@ -194,3 +194,7 @@ class TestLosses(unittest.TestCase):
def tearDown(self):
pass
if __name__ == '__main__':
unittest.main()