diff --git a/prototorch/functions/losses.py b/prototorch/functions/losses.py index bac44b6..bb8b9e4 100644 --- a/prototorch/functions/losses.py +++ b/prototorch/functions/losses.py @@ -12,12 +12,9 @@ def glvq_loss(distances, target_labels, prototype_labels): matcher = torch.eq(torch.sum(matcher, dim=-1), nclasses) not_matcher = torch.bitwise_not(matcher) - dplus_criterion = distances * matcher > 0.0 - dminus_criterion = distances * not_matcher > 0.0 - inf = torch.full_like(distances, fill_value=float('inf')) - distances_to_wpluses = torch.where(dplus_criterion, distances, inf) - distances_to_wminuses = torch.where(dminus_criterion, distances, inf) + distances_to_wpluses = torch.where(matcher, distances, inf) + distances_to_wminuses = torch.where(not_matcher, distances, inf) dpluses = torch.min(distances_to_wpluses, dim=1, keepdim=True).values dminuses = torch.min(distances_to_wminuses, dim=1, keepdim=True).values diff --git a/prototorch/modules/losses.py b/prototorch/modules/losses.py index 468c407..c3e624b 100644 --- a/prototorch/modules/losses.py +++ b/prototorch/modules/losses.py @@ -12,7 +12,7 @@ class GLVQLoss(torch.nn.Module): super().__init__(**kwargs) self.margin = margin self.squashing = get_activation(squashing) - self.beta = beta + self.beta = torch.tensor(beta) def forward(self, outputs, targets): distances, plabels = outputs