diff --git a/prototorch/models/__init__.py b/prototorch/models/__init__.py index bd8297e..6666df7 100644 --- a/prototorch/models/__init__.py +++ b/prototorch/models/__init__.py @@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version from . import probabilistic from .cbc import CBC, ImageCBC -from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, +from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, ImageGMLVQ, SiameseGLVQ) from .lvq import LVQ1, LVQ21, MedianLVQ from .unsupervised import KNN, NeuralGas diff --git a/prototorch/models/glvq.py b/prototorch/models/glvq.py index d681d28..2b5f925 100644 --- a/prototorch/models/glvq.py +++ b/prototorch/models/glvq.py @@ -4,11 +4,11 @@ import torch import torchmetrics from prototorch.components import LabeledComponents from prototorch.functions.activations import get_activation -from prototorch.functions.competitions import wtac +from prototorch.functions.competitions import stratified_min, wtac from prototorch.functions.distances import (euclidean_distance, omega_distance, sed) from prototorch.functions.helper import get_flat -from prototorch.functions.losses import (glvq_loss, lvq1_loss, lvq21_loss) +from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss from .abstract import AbstractPrototypeModel, PrototypeImageModel @@ -260,6 +260,22 @@ class LVQMLN(SiameseGLVQ): return distances +class CELVQ(GLVQ): + """Cross-Entropy Learning Vector Quantization.""" + def __init__(self, hparams, **kwargs): + super().__init__(hparams, **kwargs) + self.loss = torch.nn.CrossEntropyLoss() + + def shared_step(self, batch, batch_idx, optimizer_idx=None): + x, y = batch + out = self._forward(x) # [None, num_protos] + plabels = self.proto_layer.component_labels + probs = -1.0 * stratified_min(out, plabels) # [None, num_classes] + batch_loss = self.loss(out, y.long()) + loss = batch_loss.sum(dim=0) + return out, loss + + class GLVQ1(GLVQ): """Generalized Learning Vector Quantization 1.""" def __init__(self, hparams, **kwargs):