Remove assert statements following codacy security recommendation
"Use of assert detected. The enclosed code will be removed when compiling to optimised byte code."
This commit is contained in:
parent
955661af95
commit
8f3a43f62a
@ -49,15 +49,25 @@ class Prototypes1D(torch.nn.Module):
|
||||
y_train = torch.as_tensor(y_train).type(dtype)
|
||||
nclasses = torch.unique(y_train).shape[0]
|
||||
|
||||
assert x_train.ndim == 2
|
||||
if x_train.ndim != 2:
|
||||
raise ValueError('`data[0].ndim != 2`.')
|
||||
|
||||
# Verify input dimension if `input_dim` is provided
|
||||
if 'input_dim' in kwargs:
|
||||
assert kwargs.pop('input_dim') == x_train.shape[1]
|
||||
input_dim = kwargs.pop('input_dim')
|
||||
if input_dim != x_train.shape[1]:
|
||||
raise ValueError(f'Provided `input_dim`={input_dim} does '
|
||||
'not match data dimension '
|
||||
f'`data[0].shape[1]`={x_train.shape[1]}')
|
||||
|
||||
# Verify the number of classes if `nclasses` is provided
|
||||
if 'nclasses' in kwargs:
|
||||
assert nclasses == kwargs.pop('nclasses')
|
||||
kwargs_nclasses = kwargs.pop('nclasses')
|
||||
if kwargs_nclasses != nclasses:
|
||||
raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
|
||||
'not match data labels '
|
||||
'`torch.unique(data[1]).shape[0]`'
|
||||
f'={nclasses}')
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
@ -94,14 +94,14 @@ class TestPrototypes(unittest.TestCase):
|
||||
dtype=torch.int32)
|
||||
|
||||
def test_prototypes1d_inputndim_with_data(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(input_dim=1,
|
||||
nclasses=1,
|
||||
prototypes_per_class=1,
|
||||
data=[[1.], [1]])
|
||||
|
||||
def test_prototypes1d_inputdim_with_data(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=2,
|
||||
nclasses=1,
|
||||
@ -110,7 +110,7 @@ class TestPrototypes(unittest.TestCase):
|
||||
data=[[[1.]], [1]])
|
||||
|
||||
def test_prototypes1d_nclasses_with_data(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(ValueError):
|
||||
_ = prototypes.Prototypes1D(
|
||||
input_dim=1,
|
||||
nclasses=1,
|
||||
@ -194,3 +194,7 @@ class TestLosses(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user