From a54acdef226a3124ad69483a5bd35a46a26b3046 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 15 Feb 2022 17:16:44 +0100 Subject: [PATCH] feat: update GLVQLoss to include a regularization term --- prototorch/core/losses.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/prototorch/core/losses.py b/prototorch/core/losses.py index c0977d3..c92424d 100644 --- a/prototorch/core/losses.py +++ b/prototorch/core/losses.py @@ -106,19 +106,31 @@ def margin_loss(y_pred, y_true, margin=0.3): class GLVQLoss(torch.nn.Module): - def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs): + + def __init__(self, + margin=0.0, + transfer_fn="identity", + beta=10, + add_dp=False, + **kwargs): super().__init__(**kwargs) self.margin = margin self.transfer_fn = get_activation(transfer_fn) self.beta = torch.tensor(beta) + self.add_dp = add_dp def forward(self, outputs, targets, plabels): - mu = glvq_loss(outputs, targets, prototype_labels=plabels) + # mu = glvq_loss(outputs, targets, prototype_labels=plabels) + dp, dm = _get_dp_dm(outputs, targets, plabels) + mu = (dp - dm) / (dp + dm) + if self.add_dp: + mu = mu + dp batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta) return batch_loss.sum() class MarginLoss(torch.nn.modules.loss._Loss): + def __init__(self, margin=0.3, size_average=None, @@ -132,6 +144,7 @@ class MarginLoss(torch.nn.modules.loss._Loss): class NeuralGasEnergy(torch.nn.Module): + def __init__(self, lm, **kwargs): super().__init__(**kwargs) self.lm = lm @@ -152,6 +165,7 @@ class NeuralGasEnergy(torch.nn.Module): class GrowingNeuralGasEnergy(NeuralGasEnergy): + def __init__(self, topology_layer, **kwargs): super().__init__(**kwargs) self.topology_layer = topology_layer