diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py index c92424d..873538a 100644 --- a/prototorch/core/losses.py +++ b/prototorch/core/losses.py @@ -120,8 +120,8 @@ class GLVQLoss(torch.nn.Module): self.add_dp = add_dp def forward(self, outputs, targets, plabels): - # mu = glvq_loss(outputs, targets, prototype_labels=plabels) - dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) + # mu = glvq_loss(outputs, targets, plabels) + dp, dm = _get_dp_dm(outputs, targets, plabels) mu = (dp - dm) / (dp + dm) if self.add_dp: mu = mu + dp