fix: incorrect variable names in GLVQLoss.forward

This commit is contained in:
Jensun Ravichandran 2022-03-09 13:20:00 +01:00
parent a54acdef22
commit 695559fd4a
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -120,8 +120,8 @@ class GLVQLoss(torch.nn.Module):
self.add_dp = add_dp self.add_dp = add_dp
def forward(self, outputs, targets, plabels): def forward(self, outputs, targets, plabels):
# mu = glvq_loss(outputs, targets, prototype_labels=plabels) # mu = glvq_loss(outputs, targets, plabels)
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) dp, dm = _get_dp_dm(outputs, targets, plabels)
mu = (dp - dm) / (dp + dm) mu = (dp - dm) / (dp + dm)
if self.add_dp: if self.add_dp:
mu = mu + dp mu = mu + dp