refactor: use GLVQLoss
instead of LossLayer
This commit is contained in:
parent
0f9f24e36a
commit
9d38123114
@ -66,7 +66,7 @@ if __name__ == "__main__":
|
|||||||
args,
|
args,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
# es, # FIXME
|
es,
|
||||||
pruning,
|
pruning,
|
||||||
],
|
],
|
||||||
terminate_on_nan=True,
|
terminate_on_nan=True,
|
||||||
|
@ -6,9 +6,8 @@ from torch.nn.parameter import Parameter
|
|||||||
from ..core.competitions import wtac
|
from ..core.competitions import wtac
|
||||||
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
from ..core.distances import lomega_distance, omega_distance, squared_euclidean_distance
|
||||||
from ..core.initializers import EyeTransformInitializer
|
from ..core.initializers import EyeTransformInitializer
|
||||||
from ..core.losses import glvq_loss, lvq1_loss, lvq21_loss
|
from ..core.losses import GLVQLoss, lvq1_loss, lvq21_loss
|
||||||
from ..core.transforms import LinearTransform
|
from ..core.transforms import LinearTransform
|
||||||
from ..nn.activations import get_activation
|
|
||||||
from ..nn.wrappers import LambdaLayer, LossLayer
|
from ..nn.wrappers import LambdaLayer, LossLayer
|
||||||
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
from .abstract import ImagePrototypesMixin, SupervisedPrototypeModel
|
||||||
|
|
||||||
@ -19,15 +18,16 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
# Default hparams
|
# Default hparams
|
||||||
|
self.hparams.setdefault("margin", 0.0)
|
||||||
self.hparams.setdefault("transfer_fn", "identity")
|
self.hparams.setdefault("transfer_fn", "identity")
|
||||||
self.hparams.setdefault("transfer_beta", 10.0)
|
self.hparams.setdefault("transfer_beta", 10.0)
|
||||||
|
|
||||||
# Layers
|
|
||||||
transfer_fn = get_activation(self.hparams.transfer_fn)
|
|
||||||
self.transfer_layer = LambdaLayer(transfer_fn)
|
|
||||||
|
|
||||||
# Loss
|
# Loss
|
||||||
self.loss = LossLayer(glvq_loss)
|
self.loss = GLVQLoss(
|
||||||
|
margin=self.hparams.margin,
|
||||||
|
transfer_fn=self.hparams.transfer_fn,
|
||||||
|
beta=self.hparams.transfer_beta,
|
||||||
|
)
|
||||||
|
|
||||||
def initialize_prototype_win_ratios(self):
|
def initialize_prototype_win_ratios(self):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
@ -56,9 +56,7 @@ class GLVQ(SupervisedPrototypeModel):
|
|||||||
x, y = batch
|
x, y = batch
|
||||||
out = self.compute_distances(x)
|
out = self.compute_distances(x)
|
||||||
plabels = self.proto_layer.labels
|
plabels = self.proto_layer.labels
|
||||||
mu = self.loss(out, y, prototype_labels=plabels)
|
loss = self.loss(out, y, plabels)
|
||||||
batch_loss = self.transfer_layer(mu, beta=self.hparams.transfer_beta)
|
|
||||||
loss = batch_loss.sum()
|
|
||||||
return out, loss
|
return out, loss
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
Loading…
Reference in New Issue
Block a user