refactor: use GLVQLoss instead of LossLayer

Jensun Ravichandran 2021-07-06 17:09:21 +02:00
parent 0f9f24e36a
commit 9d38123114
2 changed files with 9 additions and 11 deletions

File 1 of 2:

@@ -66,7 +66,7 @@ if __name__ == "__main__":
         args,
         callbacks=[
             vis,
-            # es,  # FIXME
+            es,
             pruning,
         ],
         terminate_on_nan=True,
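The re-enabled `es` callback is defined outside this hunk; it is presumably a PyTorch Lightning early-stopping callback. A hypothetical construction for orientation only (the name `es` comes from the diff, but the monitored metric and patience below are assumptions):

from pytorch_lightning.callbacks import EarlyStopping

# Hypothetical sketch of the `es` callback re-enabled above; the
# monitored quantity and patience are illustrative assumptions.
es = EarlyStopping(
    monitor="train_loss",
    patience=10,
    mode="min",
)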

File 2 of 2:

@@ -6,9 +6,8 @@ from torch.nn.parameter import Parameter
 from ..core.competitions import wtac
 from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
 from ..core.initializers import EyeTransformInitializer
-from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
+from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
 from ..core.transforms import LinearTransform
 from ..nn.activations import get_activation
 from ..nn.wrappers import LambdaLayer, LossLayer
 from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
@@ -19,15 +18,16 @@ class GLVQ(SupervisedPrototypeModel):
         super().__init__(hparams, **kwargs)
 
         # Default hparams
+        self.hparams.setdefault("margin", 0.0)
         self.hparams.setdefault("transfer_fn", "identity")
         self.hparams.setdefault("transfer_beta", 10.0)
 
-        # Layers
-        transfer_fn = get_activation(self.hparams.transfer_fn)
-        self.transfer_layer = LambdaLayer(transfer_fn)
-
         # Loss
-        self.loss = LossLayer(glvq_loss)
+        self.loss = GLVQLoss(
+            margin=self.hparams.margin,
+            transfer_fn=self.hparams.transfer_fn,
+            beta=self.hparams.transfer_beta,
+        )
 
     def initialize_prototype_win_ratios(self):
         self.register_buffer(
@@ -56,9 +56,7 @@ class GLVQ(SupervisedPrototypeModel):
         x, y = batch
         out = self.compute_distances(x)
         plabels = self.proto_layer.labels
-        mu = self.loss(out, y, prototype_labels=plabels)
-        batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
-        loss = batch_loss.sum()
+        loss = self.loss(out, y, plabels)
         return out, loss
 
     def training_step(self, batch, batch_idx, optimizer_idx=None):
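Taken together, the refactor moves the margin, the transfer function, and the final reduction out of shared_step (where LossLayer and LambdaLayer composed them by hand) and into the GLVQLoss module. The following is a minimal sketch of the behavior the model now relies on, assuming the classic GLVQ criterion mu = (d_correct - d_wrong) / (d_correct + d_wrong); it is an illustration of the idea, not prototorch's actual GLVQLoss implementation, and the margin handling in particular is an assumption:

import torch

def glvq_criterion(distances, targets, plabels):
    # Per-sample GLVQ score (d+ - d-) / (d+ + d-), where d+ is the
    # distance to the nearest prototype with the correct label and d-
    # the distance to the nearest prototype with any other label.
    matching = targets.unsqueeze(1) == plabels.unsqueeze(0)
    inf = torch.full_like(distances, float("inf"))
    d_correct = torch.where(matching, distances, inf).min(dim=1).values
    d_wrong = torch.where(~matching, distances, inf).min(dim=1).values
    return (d_correct - d_wrong) / (d_correct + d_wrong)

class GLVQLossSketch(torch.nn.Module):
    # Illustrative stand-in for prototorch's GLVQLoss: it folds margin,
    # transfer function, and the final sum (previously composed from
    # LossLayer and LambdaLayer in shared_step) into a single module.
    def __init__(self, margin=0.0, transfer_fn="identity", beta=10.0):
        super().__init__()
        self.margin = margin
        self.beta = beta
        if transfer_fn == "sigmoid_beta":
            # Smooth squashing of the score, sharpened by beta.
            self.transfer = lambda x: torch.sigmoid(self.beta * x)
        else:
            self.transfer = lambda x: x  # "identity"

    def forward(self, distances, targets, plabels):
        mu = glvq_criterion(distances, targets, plabels)
        return self.transfer(mu + self.margin).sum()

With this in place, shared_step shrinks to the single call shown in the last hunk: loss = self.loss(out, y, plabels).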