Update tests/test_modules.py

This commit is contained in:
Jensun Ravichandran 2020-07-13 09:32:12 +02:00
parent 6e72b9267a
commit fa72c7156e

View File

@ -245,7 +245,7 @@ class TestLosses(unittest.TestCase):
def test_glvqloss_init(self): def test_glvqloss_init(self):
_ = losses.GLVQLoss(0, 'swish_beta', beta=20) _ = losses.GLVQLoss(0, 'swish_beta', beta=20)
def test_glvqloss_forward(self): def test_glvqloss_forward_1ppc(self):
criterion = losses.GLVQLoss(margin=0, criterion = losses.GLVQLoss(margin=0,
squashing='sigmoid_beta', squashing='sigmoid_beta',
beta=100) beta=100)
@ -257,5 +257,23 @@ class TestLosses(unittest.TestCase):
loss_value = loss.item() loss_value = loss.item()
self.assertAlmostEqual(loss_value, 0.0) self.assertAlmostEqual(loss_value, 0.0)
def test_glvqloss_forward_2ppc(self):
criterion = losses.GLVQLoss(margin=0,
squashing='sigmoid_beta',
beta=100)
d = torch.stack([
torch.ones(100),
torch.ones(100),
torch.zeros(100),
torch.ones(100)
],
dim=1)
labels = torch.tensor([0, 0, 1, 1])
targets = torch.ones(100)
outputs = [d, labels]
loss = criterion(outputs, targets)
loss_value = loss.item()
self.assertAlmostEqual(loss_value, 0.0)
def tearDown(self): def tearDown(self):
pass pass