[WIP] Add CELVQ

TODO: Ensure that the per-class distances/probs derived from the plabels are
ordered like the target labels, i.e. column c of the output corresponds to class label c.
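
The concern can be illustrated with a short stand-alone sketch. The stratified-min
helper below is a simplified stand-in written for this note (not prototorch's own
implementation); it makes the column-ordering assumption explicit by sorting the
unique prototype labels:

    import torch

    def stratified_min_sorted(distances, plabels):
        """Per-class minimum distance; columns ordered by ascending class label."""
        classes = torch.unique(plabels, sorted=True)           # e.g. tensor([0, 1, 2])
        cols = [distances[:, plabels == c].min(dim=1).values   # best prototype per class
                for c in classes]
        return torch.stack(cols, dim=1)                        # [batch, num_classes]

    distances = torch.rand(8, 6)                # 8 samples, 6 prototypes
    plabels = torch.tensor([0, 0, 1, 1, 2, 2])  # prototype labels
    targets = torch.randint(0, 3, (8,))         # integer target labels in {0, 1, 2}
    probs = -1.0 * stratified_min_sorted(distances, plabels)
    loss = torch.nn.CrossEntropyLoss()(probs, targets)  # only correct if column c means class c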
Jensun Ravichandran 2021-05-27 17:40:16 +02:00
parent 41b2a2f496
commit b7edee02c3
2 changed files with 19 additions and 3 deletions


@@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version
 from . import probabilistic
 from .cbc import CBC, ImageCBC
-from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
+from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
                    ImageGMLVQ, SiameseGLVQ)
 from .lvq import LVQ1, LVQ21, MedianLVQ
 from .unsupervised import KNN, NeuralGas


@@ -4,11 +4,11 @@ import torch
 import torchmetrics
 from prototorch.components import LabeledComponents
 from prototorch.functions.activations import get_activation
-from prototorch.functions.competitions import wtac
+from prototorch.functions.competitions import stratified_min, wtac
 from prototorch.functions.distances import (euclidean_distance, omega_distance,
                                             sed)
 from prototorch.functions.helper import get_flat
-from prototorch.functions.losses import (glvq_loss, lvq1_loss, lvq21_loss)
+from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
 
 from .abstract import AbstractPrototypeModel, PrototypeImageModel
@@ -260,6 +260,22 @@ class LVQMLN(SiameseGLVQ):
         return distances
 
 
+class CELVQ(GLVQ):
+    """Cross-Entropy Learning Vector Quantization."""
+    def __init__(self, hparams, **kwargs):
+        super().__init__(hparams, **kwargs)
+        self.loss = torch.nn.CrossEntropyLoss()
+
+    def shared_step(self, batch, batch_idx, optimizer_idx=None):
+        x, y = batch
+        out = self._forward(x)  # [None, num_protos]
+        plabels = self.proto_layer.component_labels
+        probs = -1.0 * stratified_min(out, plabels)  # [None, num_classes]
+        batch_loss = self.loss(probs, y.long())  # cross-entropy on the per-class scores
+        loss = batch_loss.sum(dim=0)
+        return out, loss
+
+
 class GLVQ1(GLVQ):
     """Generalized Learning Vector Quantization 1."""
     def __init__(self, hparams, **kwargs):
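
A note on what the new shared_step effectively optimizes: cross-entropy on the negated
per-class minimum distances is the negative log of a softmax posterior over classes, so
closer prototypes mean higher class probability. A small self-contained check in plain
torch (no prototorch needed; shapes and values below are made up for illustration):

    import torch

    distances = torch.rand(4, 3)          # per-class minimum distances: 4 samples, 3 classes
    targets = torch.tensor([0, 2, 1, 1])  # integer target labels

    probs = -1.0 * distances              # scores handed to cross-entropy
    ce = torch.nn.CrossEntropyLoss(reduction="none")(probs, targets)

    # Equivalent by hand: p(c|x) = exp(-d_c) / sum_j exp(-d_j)
    log_posteriors = torch.log_softmax(-distances, dim=1)
    manual = -log_posteriors[torch.arange(4), targets]
    assert torch.allclose(ce, manual)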