feat: update GLVQLoss to include a regularization term
This commit is contained in:
parent
bebd13868f
commit
a54acdef22
@ -106,19 +106,31 @@ def margin_loss(y_pred, y_true, margin=0.3):
|
|||||||
|
|
||||||
|
|
||||||
class GLVQLoss(torch.nn.Module):
|
class GLVQLoss(torch.nn.Module):
|
||||||
def __init__(self, margin=0.0, transfer_fn="identity", beta=10, **kwargs):
|
|
||||||
|
def __init__(self,
|
||||||
|
margin=0.0,
|
||||||
|
transfer_fn="identity",
|
||||||
|
beta=10,
|
||||||
|
add_dp=False,
|
||||||
|
**kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.margin = margin
|
self.margin = margin
|
||||||
self.transfer_fn = get_activation(transfer_fn)
|
self.transfer_fn = get_activation(transfer_fn)
|
||||||
self.beta = torch.tensor(beta)
|
self.beta = torch.tensor(beta)
|
||||||
|
self.add_dp = add_dp
|
||||||
|
|
||||||
def forward(self, outputs, targets, plabels):
|
def forward(self, outputs, targets, plabels):
|
||||||
mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
# mu = glvq_loss(outputs, targets, prototype_labels=plabels)
|
||||||
|
dp, dm = _get_dp_dm(distances, target_labels, prototype_labels)
|
||||||
|
mu = (dp - dm) / (dp + dm)
|
||||||
|
if self.add_dp:
|
||||||
|
mu = mu + dp
|
||||||
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
|
batch_loss = self.transfer_fn(mu + self.margin, beta=self.beta)
|
||||||
return batch_loss.sum()
|
return batch_loss.sum()
|
||||||
|
|
||||||
|
|
||||||
class MarginLoss(torch.nn.modules.loss._Loss):
|
class MarginLoss(torch.nn.modules.loss._Loss):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
margin=0.3,
|
margin=0.3,
|
||||||
size_average=None,
|
size_average=None,
|
||||||
@ -132,6 +144,7 @@ class MarginLoss(torch.nn.modules.loss._Loss):
|
|||||||
|
|
||||||
|
|
||||||
class NeuralGasEnergy(torch.nn.Module):
|
class NeuralGasEnergy(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, lm, **kwargs):
|
def __init__(self, lm, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.lm = lm
|
self.lm = lm
|
||||||
@ -152,6 +165,7 @@ class NeuralGasEnergy(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
class GrowingNeuralGasEnergy(NeuralGasEnergy):
|
||||||
|
|
||||||
def __init__(self, topology_layer, **kwargs):
|
def __init__(self, topology_layer, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.topology_layer = topology_layer
|
self.topology_layer = topology_layer
|
||||||
|
Loading…
Reference in New Issue
Block a user