[FEATURE] Add PLVQ model

This commit is contained in:
Alexander Engelsberger
2021-06-08 15:01:08 +02:00
committed by Alexander Engelsberger
parent fc11d78b38
commit c87ed5ba8b
7 changed files with 61 additions and 32 deletions

View File

@@ -4,22 +4,11 @@ from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, SLVQ
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .vis import *

View File

@@ -2,11 +2,13 @@
import torch
from prototorch.functions.losses import nllr_loss, rslvq_loss
from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
from prototorch.functions.transforms import gaussian
from prototorch.functions.pooling import (stratified_min_pooling,
stratified_sum_pooling)
from prototorch.functions.transforms import (GaussianPrior,
RankScaledGaussianPrior)
from prototorch.modules import LambdaLayer, LossLayer
from .glvq import GLVQ
from .glvq import GLVQ, SiameseGMLVQ
class CELVQ(GLVQ):
@@ -32,13 +34,12 @@ class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs)
self.conditional_distribution = gaussian
self.conditional_distribution = None
self.rejection_confidence = rejection_confidence
def forward(self, x):
distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances,
self.hparams.variance)
conditional = self.conditional_distribution(distances)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
@@ -66,6 +67,7 @@ class SLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss = LossLayer(nllr_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class RSLVQ(ProbabilisticLVQ):
@@ -73,3 +75,25 @@ class RSLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss = LossLayer(rslvq_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
"""Probabilistic Learning Vector Quantization.
TODO: Use Backbone LVQ instead
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conditional_distribution = RankScaledGaussianPrior(
self.hparams.lambd)
self.loss = torch.nn.KLDivLoss()
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.forward(x)
y_dist = torch.nn.functional.one_hot(
y.long(), num_classes=self.num_classes).float()
batch_loss = self.loss(out, y_dist)
loss = batch_loss.sum(dim=0)
return loss