Remove assert statements following codacy security recommendation
"Use of assert detected. The enclosed code will be removed when compiling to optimised byte code."
This commit is contained in:
parent
955661af95
commit
8f3a43f62a
@ -49,15 +49,25 @@ class Prototypes1D(torch.nn.Module):
|
|||||||
y_train = torch.as_tensor(y_train).type(dtype)
|
y_train = torch.as_tensor(y_train).type(dtype)
|
||||||
nclasses = torch.unique(y_train).shape[0]
|
nclasses = torch.unique(y_train).shape[0]
|
||||||
|
|
||||||
assert x_train.ndim == 2
|
if x_train.ndim != 2:
|
||||||
|
raise ValueError('`data[0].ndim != 2`.')
|
||||||
|
|
||||||
# Verify input dimension if `input_dim` is provided
|
# Verify input dimension if `input_dim` is provided
|
||||||
if 'input_dim' in kwargs:
|
if 'input_dim' in kwargs:
|
||||||
assert kwargs.pop('input_dim') == x_train.shape[1]
|
input_dim = kwargs.pop('input_dim')
|
||||||
|
if input_dim != x_train.shape[1]:
|
||||||
|
raise ValueError(f'Provided `input_dim`={input_dim} does '
|
||||||
|
'not match data dimension '
|
||||||
|
f'`data[0].shape[1]`={x_train.shape[1]}')
|
||||||
|
|
||||||
# Verify the number of classes if `nclasses` is provided
|
# Verify the number of classes if `nclasses` is provided
|
||||||
if 'nclasses' in kwargs:
|
if 'nclasses' in kwargs:
|
||||||
assert nclasses == kwargs.pop('nclasses')
|
kwargs_nclasses = kwargs.pop('nclasses')
|
||||||
|
if kwargs_nclasses != nclasses:
|
||||||
|
raise ValueError(f'Provided `nclasses={kwargs_nclasses}` does '
|
||||||
|
'not match data labels '
|
||||||
|
'`torch.unique(data[1]).shape[0]`'
|
||||||
|
f'={nclasses}')
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
@ -94,14 +94,14 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
dtype=torch.int32)
|
dtype=torch.int32)
|
||||||
|
|
||||||
def test_prototypes1d_inputndim_with_data(self):
|
def test_prototypes1d_inputndim_with_data(self):
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(input_dim=1,
|
_ = prototypes.Prototypes1D(input_dim=1,
|
||||||
nclasses=1,
|
nclasses=1,
|
||||||
prototypes_per_class=1,
|
prototypes_per_class=1,
|
||||||
data=[[1.], [1]])
|
data=[[1.], [1]])
|
||||||
|
|
||||||
def test_prototypes1d_inputdim_with_data(self):
|
def test_prototypes1d_inputdim_with_data(self):
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
input_dim=2,
|
input_dim=2,
|
||||||
nclasses=1,
|
nclasses=1,
|
||||||
@ -110,7 +110,7 @@ class TestPrototypes(unittest.TestCase):
|
|||||||
data=[[[1.]], [1]])
|
data=[[[1.]], [1]])
|
||||||
|
|
||||||
def test_prototypes1d_nclasses_with_data(self):
|
def test_prototypes1d_nclasses_with_data(self):
|
||||||
with self.assertRaises(AssertionError):
|
with self.assertRaises(ValueError):
|
||||||
_ = prototypes.Prototypes1D(
|
_ = prototypes.Prototypes1D(
|
||||||
input_dim=1,
|
input_dim=1,
|
||||||
nclasses=1,
|
nclasses=1,
|
||||||
@ -194,3 +194,7 @@ class TestLosses(unittest.TestCase):
|
|||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user