refactor: refactor GLVQLoss

This commit is contained in:
Jensun Ravichandran 2021-07-06 17:01:28 +02:00
parent fdb9a7c66d
commit 99be965581
No known key found for this signature in database
GPG Key ID: 3331B0F18B6D4D93

View File

@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3):
class GLVQLoss(torch.nn.Module):
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
self.transfer_fn = get_activation(transfer_fn)
self.beta = torch.tensor(beta)
def forward(self, outputs, targets):
distances, plabels = outputs
mu = glvq_loss(distances, targets, prototype_labels=plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0)
def forward(self, outputs, targets, plabels):
mu = glvq_loss(outputs, targets, prototype_labels=plabels)
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
return batch_loss.sum()
class MarginLoss(torch.nn.modules.loss._Loss):