Update tests/test_modules.py
This commit is contained in:
parent
6e72b9267a
commit
fa72c7156e
@ -245,7 +245,7 @@ class TestLosses(unittest.TestCase):
|
||||
def test_glvqloss_init(self):
|
||||
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
||||
|
||||
def test_glvqloss_forward(self):
|
||||
def test_glvqloss_forward_1ppc(self):
|
||||
criterion = losses.GLVQLoss(margin=0,
|
||||
squashing='sigmoid_beta',
|
||||
beta=100)
|
||||
@ -257,5 +257,23 @@ class TestLosses(unittest.TestCase):
|
||||
loss_value = loss.item()
|
||||
self.assertAlmostEqual(loss_value, 0.0)
|
||||
|
||||
def test_glvqloss_forward_2ppc(self):
|
||||
criterion = losses.GLVQLoss(margin=0,
|
||||
squashing='sigmoid_beta',
|
||||
beta=100)
|
||||
d = torch.stack([
|
||||
torch.ones(100),
|
||||
torch.ones(100),
|
||||
torch.zeros(100),
|
||||
torch.ones(100)
|
||||
],
|
||||
dim=1)
|
||||
labels = torch.tensor([0, 0, 1, 1])
|
||||
targets = torch.ones(100)
|
||||
outputs = [d, labels]
|
||||
loss = criterion(outputs, targets)
|
||||
loss_value = loss.item()
|
||||
self.assertAlmostEqual(loss_value, 0.0)
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user