Update tests/test_modules.py
This commit is contained in:
parent
6e72b9267a
commit
fa72c7156e
@ -245,7 +245,7 @@ class TestLosses(unittest.TestCase):
|
|||||||
def test_glvqloss_init(self):
|
def test_glvqloss_init(self):
|
||||||
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
_ = losses.GLVQLoss(0, 'swish_beta', beta=20)
|
||||||
|
|
||||||
def test_glvqloss_forward(self):
|
def test_glvqloss_forward_1ppc(self):
|
||||||
criterion = losses.GLVQLoss(margin=0,
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
squashing='sigmoid_beta',
|
squashing='sigmoid_beta',
|
||||||
beta=100)
|
beta=100)
|
||||||
@ -257,5 +257,23 @@ class TestLosses(unittest.TestCase):
|
|||||||
loss_value = loss.item()
|
loss_value = loss.item()
|
||||||
self.assertAlmostEqual(loss_value, 0.0)
|
self.assertAlmostEqual(loss_value, 0.0)
|
||||||
|
|
||||||
|
def test_glvqloss_forward_2ppc(self):
|
||||||
|
criterion = losses.GLVQLoss(margin=0,
|
||||||
|
squashing='sigmoid_beta',
|
||||||
|
beta=100)
|
||||||
|
d = torch.stack([
|
||||||
|
torch.ones(100),
|
||||||
|
torch.ones(100),
|
||||||
|
torch.zeros(100),
|
||||||
|
torch.ones(100)
|
||||||
|
],
|
||||||
|
dim=1)
|
||||||
|
labels = torch.tensor([0, 0, 1, 1])
|
||||||
|
targets = torch.ones(100)
|
||||||
|
outputs = [d, labels]
|
||||||
|
loss = criterion(outputs, targets)
|
||||||
|
loss_value = loss.item()
|
||||||
|
self.assertAlmostEqual(loss_value, 0.0)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
pass
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user