From 99be9655819e6e2baf9e1cf47a784aa329289183 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 6 Jul 2021 17:01:28 +0200 Subject: [PATCH] refactor: refactor `GLVQLoss` --- prototorch/core/losses.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py index 1a32103..f413de5 100644 --- a/prototorch/core/losses.py +++ b/prototorch/core/losses.py @@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3): class GLVQLoss(torch.nn.Module): - def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs): + def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs): super().__init__(**kwargs) self.margin = margin - self.squashing = get_activation(squashing) + self.transfer_fn = get_activation(transfer_fn) self.beta = torch.tensor(beta) - def forward(self, outputs, targets): - distances, plabels = outputs - mu = glvq_loss(distances, targets, prototype_labels=plabels) - batch_loss = self.squashing(mu + self.margin, beta=self.beta) - return torch.sum(batch_loss, dim=0) + def forward(self, outputs, targets, plabels): + mu = glvq_loss(outputs, targets, prototype_labels=plabels) + batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta) + return batch_loss.sum() class MarginLoss(torch.nn.modules.loss._Loss):