fix: incorrect variable names in GLVQLoss.forward
This commit is contained in:
parent
a54acdef22
commit
695559fd4a
@ -120,8 +120,8 @@ class GLVQLoss(torch.nn.Module):
|
|||||||
self.add_dp = add_dp
|
self.add_dp = add_dp
|
||||||
|
|
||||||
def forward(self, outputs, targets, plabels):
|
def forward(self, outputs, targets, plabels):
|
||||||
# mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
# mu = glvq_loss(outputs, targets, plabels)
|
||||||
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
dp, dm = _get_dp_dm(outputs, targets, plabels)
|
||||||
mu = (dp - dm) / (dp + dm)
|
mu = (dp - dm) / (dp + dm)
|
||||||
if self.add_dp:
|
if self.add_dp:
|
||||||
mu = mu + dp
|
mu = mu + dp
|
||||||
|
Loading…
Reference in New Issue
Block a user