From 695559fd4a0bdcb0206dc76f082b20ee0910bf23 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 9 Mar 2022 13:20:00 +0100 Subject: [PATCH] fix: incorrect variable names in `GLVQLoss.forward` --- prototorch/core/losses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py index c92424d..873538a 100644 --- a/prototorch/core/losses.py +++ b/prototorch/core/losses.py @@ -120,8 +120,8 @@ class GLVQLoss(torch.nn.Module): self.add_dp = add_dp def forward(self, outputs, targets, plabels): - # mu = glvq_loss(outputs, targets, prototype_labels=plabels) - dp, dm = _get_dp_dm(distances, target_labels, prototype_labels) + # mu = glvq_loss(outputs, targets, plabels) + dp, dm = _get_dp_dm(outputs, targets, plabels) mu = (dp - dm) / (dp + dm) if self.add_dp: mu = mu + dp