refactor: use GLVQLoss
instead of LossLayer
This commit is contained in:
parent
0f9f24e36a
commit
9d38123114
@ -66,7 +66,7 @@ if __name__ == "__main__":
|
||||
args,
|
||||
callbacks=[
|
||||
vis,
|
||||
# es, # FIXME
|
||||
es,
|
||||
pruning,
|
||||
],
|
||||
terminate_on_nan=True,
|
||||
|
@ -6,9 +6,8 @@ from torch.nn.parameter import Parameter
|
||||
from ..core.competitions import wtac
|
||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||
from ..core.initializers import EyeTransformInitializer
|
||||
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
||||
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
||||
from ..core.transforms import LinearTransform
|
||||
from ..nn.activations import get_activation
|
||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||
|
||||
@ -19,15 +18,16 @@ class GLVQ(SupervisedPrototypeModel):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# Default hparams
|
||||
self.hparams.setdefault("margin", 0.0)
|
||||
self.hparams.setdefault("transfer_fn", "identity")
|
||||
self.hparams.setdefault("transfer_beta", 10.0)
|
||||
|
||||
# Layers
|
||||
transfer_fn = get_activation(self.hparams.transfer_fn)
|
||||
self.transfer_layer = LambdaLayer(transfer_fn)
|
||||
|
||||
# Loss
|
||||
self.loss = LossLayer(glvq_loss)
|
||||
self.loss = GLVQLoss(
|
||||
margin=self.hparams.margin,
|
||||
transfer_fn=self.hparams.transfer_fn,
|
||||
beta=self.hparams.transfer_beta,
|
||||
)
|
||||
|
||||
def initialize_prototype_win_ratios(self):
|
||||
self.register_buffer(
|
||||
@ -56,9 +56,7 @@ class GLVQ(SupervisedPrototypeModel):
|
||||
x, y = batch
|
||||
out = self.compute_distances(x)
|
||||
plabels = self.proto_layer.labels
|
||||
mu = self.loss(out, y, prototype_labels=plabels)
|
||||
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
|
||||
loss = batch_loss.sum()
|
||||
loss = self.loss(out, y, plabels)
|
||||
return out, loss
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
|
Loading…
Reference in New Issue
Block a user