fix: incorrect variable names in GLVQLoss.forward
This commit is contained in:
parent
a54acdef22
commit
695559fd4a
@ -120,8 +120,8 @@ class GLVQLoss(torch.nn.Module):
|
||||
self.add_dp = add_dp
|
||||
|
||||
def forward(self, outputs, targets, plabels):
|
||||
# mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||
# mu = glvq_loss(outputs, targets, plabels)
|
||||
dp, dm = _get_dp_dm(outputs, targets, plabels)
|
||||
mu = (dp - dm) / (dp + dm)
|
||||
if self.add_dp:
|
||||
mu = mu + dp
|
||||
|
Loading…
Reference in New Issue
Block a user