refactor: refactor GLVQLoss
This commit is contained in:
parent
fdb9a7c66d
commit
99be965581
@ -106,17 +106,16 @@ def margin_loss(y_pred, y_true, margin=0.3):
|
||||
|
||||
|
||||
class GLVQLoss(torch.nn.Module):
|
||||
def __init__(self, margin=0.0, squashing="identity", beta=10, **kwargs):
|
||||
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.margin = margin
|
||||
self.squashing = get_activation(squashing)
|
||||
self.transfer_fn = get_activation(transfer_fn)
|
||||
self.beta = torch.tensor(beta)
|
||||
|
||||
def forward(self, outputs, targets):
|
||||
distances, plabels = outputs
|
||||
mu = glvq_loss(distances, targets, prototype_labels=plabels)
|
||||
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
|
||||
return torch.sum(batch_loss, dim=0)
|
||||
def forward(self, outputs, targets, plabels):
|
||||
mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
||||
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
|
||||
return batch_loss.sum()
|
||||
|
||||
|
||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||
|
Loading…
Reference in New Issue
Block a user