feat: update GLVQLoss to include a regularization term

This commit is contained in:
Jensun Ravichandran 2022-02-15 17:16:44 +01:00
parent bebd13868f
commit a54acdef22
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921

View File

@ -106,19 +106,31 @@ def margin_loss(y_pred, y_true, margin=0.3):
class GLVQLoss(torch.nn.Module): class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
def __init__(self,
margin=0.0,
transfer_fn="identity",
beta=10,
add_dp=False,
**kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.margin = margin self.margin = margin
self.transfer_fn = get_activation(transfer_fn) self.transfer_fn = get_activation(transfer_fn)
self.beta = torch.tensor(beta) self.beta = torch.tensor(beta)
self.add_dp = add_dp
def forward(self, outputs, targets, plabels): def forward(self, outputs, targets, plabels):
mu = glvq_loss(outputs, targets, prototype_labels=plabels) # mu = glvq_loss(outputs, targets, prototype_labels=plabels)
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
mu = (dp - dm) / (dp + dm)
if self.add_dp:
mu = mu + dp
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta) batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
return batch_loss.sum() return batch_loss.sum()
class MarginLoss(torch.nn.modules.loss._Loss): class MarginLoss(torch.nn.modules.loss._Loss):
def __init__(self, def __init__(self,
margin=0.3, margin=0.3,
size_average=None, size_average=None,
@ -132,6 +144,7 @@ class MarginLoss(torch.nn.modules.loss._Loss):
class NeuralGasEnergy(torch.nn.Module): class NeuralGasEnergy(torch.nn.Module):
def __init__(self, lm, **kwargs): def __init__(self, lm, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.lm = lm self.lm = lm
@ -152,6 +165,7 @@ class NeuralGasEnergy(torch.nn.Module):
class GrowingNeuralGasEnergy(NeuralGasEnergy): class GrowingNeuralGasEnergy(NeuralGasEnergy):
def __init__(self, topology_layer, **kwargs): def __init__(self, topology_layer, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.topology_layer = topology_layer self.topology_layer = topology_layer