diff --git a/tests/test_modules.py b/tests/test_modules.py index 1beb1a0..d9090b2 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -245,7 +245,7 @@ class TestLosses(unittest.TestCase): def test_glvqloss_init(self): _ = losses.GLVQLoss(0, 'swish_beta', beta=20) - def test_glvqloss_forward(self): + def test_glvqloss_forward_1ppc(self): criterion = losses.GLVQLoss(margin=0, squashing='sigmoid_beta', beta=100) @@ -257,5 +257,23 @@ class TestLosses(unittest.TestCase): loss_value = loss.item() self.assertAlmostEqual(loss_value, 0.0) + def test_glvqloss_forward_2ppc(self): + criterion = losses.GLVQLoss(margin=0, + squashing='sigmoid_beta', + beta=100) + d = torch.stack([ + torch.ones(100), + torch.ones(100), + torch.zeros(100), + torch.ones(100) + ], + dim=1) + labels = torch.tensor([0, 0, 1, 1]) + targets = torch.ones(100) + outputs = [d, labels] + loss = criterion(outputs, targets) + loss_value = loss.item() + self.assertAlmostEqual(loss_value, 0.0) + def tearDown(self): pass