[WIP] Add CELVQ

TODO Ensure that the distances/probs corresponding to the plabels are sorted
like the target labels.
This commit is contained in:
Jensun Ravichandran 2021-05-27 17:40:16 +02:00
parent 41b2a2f496
commit b7edee02c3
2 changed files with 19 additions and 3 deletions

View File

@ -2,7 +2,7 @@ from importlib.metadata import PackageNotFoundError, version
from . import probabilistic from . import probabilistic
from .cbc import CBC, ImageCBC from .cbc import CBC, ImageCBC
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ, from .glvq import (CELVQ, GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LVQMLN, ImageGLVQ,
ImageGMLVQ, SiameseGLVQ) ImageGMLVQ, SiameseGLVQ)
from .lvq import LVQ1, LVQ21, MedianLVQ from .lvq import LVQ1, LVQ21, MedianLVQ
from .unsupervised import KNN, NeuralGas from .unsupervised import KNN, NeuralGas

View File

@ -4,11 +4,11 @@ import torch
import torchmetrics import torchmetrics
from prototorch.components import LabeledComponents from prototorch.components import LabeledComponents
from prototorch.functions.activations import get_activation from prototorch.functions.activations import get_activation
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import stratified_min, wtac
from prototorch.functions.distances import (euclidean_distance, omega_distance, from prototorch.functions.distances import (euclidean_distance, omega_distance,
sed) sed)
from prototorch.functions.helper import get_flat from prototorch.functions.helper import get_flat
from prototorch.functions.losses import (glvq_loss, lvq1_loss, lvq21_loss) from prototorch.functions.losses import glvq_loss, lvq1_loss, lvq21_loss
from .abstract import AbstractPrototypeModel, PrototypeImageModel from .abstract import AbstractPrototypeModel, PrototypeImageModel
@ -260,6 +260,22 @@ class LVQMLN(SiameseGLVQ):
return distances return distances
class CELVQ(GLVQ):
"""Cross-Entropy Learning Vector Quantization."""
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
self.loss = torch.nn.CrossEntropyLoss()
def shared_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self._forward(x) # [None, num_protos]
plabels = self.proto_layer.component_labels
probs = -1.0 * stratified_min(out, plabels) # [None, num_classes]
batch_loss = self.loss(out, y.long())
loss = batch_loss.sum(dim=0)
return out, loss
class GLVQ1(GLVQ): class GLVQ1(GLVQ):
"""Generalized Learning Vector Quantization 1.""" """Generalized Learning Vector Quantization 1."""
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):