[FEATURE] Add PLVQ model
This commit is contained in:
committed by
Alexander Engelsberger
parent
fc11d78b38
commit
c87ed5ba8b
@@ -4,22 +4,11 @@ from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||
from .cbc import CBC, ImageCBC
|
||||
from .glvq import (
|
||||
GLVQ,
|
||||
GLVQ1,
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
)
|
||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||
from .knn import KNN
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, RSLVQ, SLVQ
|
||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||
from .vis import *
|
||||
|
||||
|
@@ -2,11 +2,13 @@
|
||||
|
||||
import torch
|
||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
|
||||
from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||
from prototorch.functions.transforms import gaussian
|
||||
from prototorch.functions.pooling import (stratified_min_pooling,
|
||||
stratified_sum_pooling)
|
||||
from prototorch.functions.transforms import (GaussianPrior,
|
||||
RankScaledGaussianPrior)
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
|
||||
from .glvq import GLVQ
|
||||
from .glvq import GLVQ, SiameseGMLVQ
|
||||
|
||||
|
||||
class CELVQ(GLVQ):
|
||||
@@ -32,13 +34,12 @@ class ProbabilisticLVQ(GLVQ):
|
||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
self.conditional_distribution = gaussian
|
||||
self.conditional_distribution = None
|
||||
self.rejection_confidence = rejection_confidence
|
||||
|
||||
def forward(self, x):
|
||||
distances = self.compute_distances(x)
|
||||
conditional = self.conditional_distribution(distances,
|
||||
self.hparams.variance)
|
||||
conditional = self.conditional_distribution(distances)
|
||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
||||
device=self.device)
|
||||
posterior = conditional * prior
|
||||
@@ -66,6 +67,7 @@ class SLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(nllr_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
@@ -73,3 +75,25 @@ class RSLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(rslvq_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
"""Probabilistic Learning Vector Quantization.
|
||||
|
||||
TODO: Use Backbone LVQ instead
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conditional_distribution = RankScaledGaussianPrior(
|
||||
self.hparams.lambd)
|
||||
self.loss = torch.nn.KLDivLoss()
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self.forward(x)
|
||||
y_dist = torch.nn.functional.one_hot(
|
||||
y.long(), num_classes=self.num_classes).float()
|
||||
batch_loss = self.loss(out, y_dist)
|
||||
loss = batch_loss.sum(dim=0)
|
||||
return loss
|
||||
|
Reference in New Issue
Block a user