diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py index 1a32103..f413de5 100644 --- a/prototorch/core/losses.py +++ b/prototorch/core/losses.py @@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3): class GLVQLoss(torch.nn.Module): - def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs): + def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs): super().__init__(**kwargs) self.margin = margin - self.squashing = get_activation(squashing) + self.transfer_fn = get_activation(transfer_fn) self.beta = torch.tensor(beta) - def forward(self, outputs, targets): - distances, plabels = outputs - mu = glvq_loss(distances, targets, prototype_labels=plabels) - batch_loss = self.squashing(mu + self.margin, beta=self.beta) - return torch.sum(batch_loss, dim=0) + def forward(self, outputs, targets, plabels): + mu = glvq_loss(outputs, targets, prototype_labels=plabels) + batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta) + return batch_loss.sum() class MarginLoss(torch.nn.modules.loss._Loss):