2020-04-06 16:36:28 +02:00

22 lines
668 B
Python

"""ProtoTorch losses."""
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss
class GLVQLoss(torch.nn.Module):
"""GLVQ Loss."""
def __init__(self, margin=0.0, squashing='identity', beta=10, **kwargs):
super().__init__(**kwargs)
self.margin = margin
self.squashing = get_activation(squashing)
self.beta = beta
def forward(self, outputs, targets):
distances, plabels = outputs
mu = glvq_loss(distances, targets, plabels)
batch_loss = self.squashing(mu + self.margin, beta=self.beta)
return torch.sum(batch_loss, dim=0)