[FEATURE] Add PLVQ model

Alexander Engelsberger 2021-06-08 15:01:08 +02:00 committed by Alexander Engelsberger
parent fc11d78b38
commit c87ed5ba8b
7 changed files with 61 additions and 32 deletions

View File

@ -31,7 +31,9 @@ be available for use in your Python environment as `prototorch.models`.
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
- Siamese GLVQ
- Cross-Entropy Learning Vector Quantization (CELVQ)
- Soft Learning Vector Quantization (SLVQ)
- Robust Soft Learning Vector Quantization (RSLVQ)
- Probabilistic Learning Vector Quantization (PLVQ)
### Other
@ -49,7 +51,6 @@ be available for use in your Python environment as `prototorch.models`.
- Median-LVQ
- Generalized Tangent Learning Vector Quantization (GTLVQ)
- Probabilistic Learning Vector Quantization (PLVQ)
- Self-Incremental Learning Vector Quantization (SILVQ)
## Development setup

View File

@ -2,8 +2,6 @@
Abstract Models
========================================
.. autoclass:: prototorch.models.abstract.AbstractPrototypeModel
.. automodule:: prototorch.models.abstract
:members:
.. autoclass:: prototorch.models.abstract.PrototypeImageModel
:members:
:undoc-members:

View File

@ -8,7 +8,7 @@ Models
Unsupervised Methods
-----------------------------------------
.. autoclass:: prototorch.models.unsupervised.KNN
.. autoclass:: prototorch.models.knn.KNN
:members:
.. autoclass:: prototorch.models.unsupervised.NeuralGas
@ -80,9 +80,11 @@ Every prototype is the center of a Gaussian distribution of its class, generating
.. autoclass:: prototorch.models.probabilistic.RSLVQ
:members:
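As a hedged sketch of that mixture reading (textbook RSLVQ, not a formula taken from this repo): the class-conditional density sums isotropic Gaussians centered on the class's prototypes, with the ``variance`` hyperparameter playing the role of :math:`\sigma^2`:

.. math::

   p(x \mid y) \propto \sum_{j \,:\, c(w_j) = y} \exp\!\left(-\frac{\lVert x - w_j \rVert^2}{2\sigma^2}\right)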
Missing:
:cite:t:`villmann2018` proposed two changes to RSLVQ: first, incorporate the winning rank into the prior probability calculation; second, use a divergence as the loss function. (A sketch of the rank-scaled prior follows the PLVQ entry below.)
- PLVQ
.. autoclass:: prototorch.models.probabilistic.PLVQ
:members:
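A minimal sketch of the rank-scaled prior, assuming ``RankScaledGaussianPrior(lambd)`` down-weights each prototype's Gaussian response by its distance rank; the exact formula in ``prototorch.functions.transforms`` may differ::

    import torch

    def rank_scaled_gaussian(distances, lambd):
        # Rank prototypes per sample: 0 for the closest, 1 for the next, ...
        ranks = torch.argsort(torch.argsort(distances, dim=1), dim=1)
        # Gaussian response, decayed exponentially with the rank.
        return torch.exp(-distances) * torch.exp(-ranks / lambd)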
Classification by Component
--------------------------------------------

View File

@ -60,3 +60,14 @@
doi = {10.1162/neco.2009.11-08-908},
}
@InProceedings{villmann2018,
  author    = {Villmann, Andrea and Kaden, Marika and Saralajew, Sascha and Villmann, Thomas},
  title     = {Probabilistic Learning Vector Quantization with Cross-Entropy for Probabilistic Class Assignments in Classification Learning},
  booktitle = {Artificial Intelligence and Soft Computing},
  year      = {2018},
  publisher = {Springer International Publishing},
}

View File

@ -5,6 +5,7 @@ import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from torchvision.transforms import Lambda
if __name__ == "__main__":
# Command-line arguments
@ -24,12 +25,15 @@ if __name__ == "__main__":
# Hyperparameters
hparams = dict(
distribution=[2, 2, 3],
lr=0.05,
variance=0.1,
proto_lr=0.05,
lambd=0.1,
input_dim=2,
latent_dim=2,
bb_lr=0.01,
)
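# Note (assumption, not stated in this commit): "lambd" is the rank-scaling
# parameter of the PLVQ prior, while "proto_lr" and "bb_lr" set separate
# learning rates for the prototypes and the Siamese backbone.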
# Initialize the model
model = pt.models.probabilistic.RSLVQ(
model = pt.models.probabilistic.PLVQ(
hparams,
optimizer=torch.optim.Adam,
# prototype_initializer=pt.components.SMI(train_ds),
@ -45,7 +49,7 @@ if __name__ == "__main__":
print(model)
# Callbacks
vis = pt.models.VisGLVQ2D(data=train_ds)
vis = pt.models.VisSiameseGLVQ2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(

View File

@ -4,22 +4,11 @@ from importlib.metadata import PackageNotFoundError, version
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
from .cbc import CBC, ImageCBC
from .glvq import (
GLVQ,
GLVQ1,
GLVQ21,
GMLVQ,
GRLVQ,
LGMLVQ,
LVQMLN,
ImageGLVQ,
ImageGMLVQ,
SiameseGLVQ,
SiameseGMLVQ,
)
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, SLVQ
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .vis import *
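With this re-export, the new model is importable from the package top level; a quick check, assuming `prototorch` is installed:

    from prototorch.models import PLVQ  # re-exported from .probabilistic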

View File

@ -2,11 +2,13 @@
import torch
from prototorch.functions.losses import nllr_loss, rslvq_loss
from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
from prototorch.functions.transforms import gaussian
from prototorch.functions.pooling import (stratified_min_pooling,
stratified_sum_pooling)
from prototorch.functions.transforms import (GaussianPrior,
RankScaledGaussianPrior)
from prototorch.modules import LambdaLayer, LossLayer
from .glvq import GLVQ
from .glvq import GLVQ, SiameseGMLVQ
class CELVQ(GLVQ):
@ -32,13 +34,12 @@ class ProbabilisticLVQ(GLVQ):
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
super().__init__(hparams, **kwargs)
self.conditional_distribution = gaussian
self.conditional_distribution = None
self.rejection_confidence = rejection_confidence
def forward(self, x):
distances = self.compute_distances(x)
conditional = self.conditional_distribution(distances,
self.hparams.variance)
conditional = self.conditional_distribution(distances)
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
device=self.device)
posterior = conditional * prior
@ -66,6 +67,7 @@ class SLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss = LossLayer(nllr_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class RSLVQ(ProbabilisticLVQ):
@ -73,3 +75,25 @@ class RSLVQ(ProbabilisticLVQ):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.loss = LossLayer(rslvq_loss)
self.conditional_distribution = GaussianPrior(self.hparams.variance)
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
"""Probabilistic Learning Vector Quantization.
TODO: Use Backbone LVQ instead
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.conditional_distribution = RankScaledGaussianPrior(
self.hparams.lambd)
self.loss = torch.nn.KLDivLoss()
def training_step(self, batch, batch_idx, optimizer_idx=None):
x, y = batch
out = self.forward(x)
y_dist = torch.nn.functional.one_hot(
y.long(), num_classes=self.num_classes).float()
batch_loss = self.loss(out, y_dist)
loss = batch_loss.sum(dim=0)
return loss
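Putting it together, a minimal usage sketch based on the example script above; the hyperparameters are the ones this commit introduces, and in practice a dataset-based prototype initializer (like the commented-out `SMI` line) is likely needed:

    import prototorch as pt
    import torch

    hparams = dict(
        distribution=[2, 2, 3],  # prototypes per class
        proto_lr=0.05,           # learning rate for the prototypes
        lambd=0.1,               # rank-scaling parameter of the prior
        input_dim=2,
        latent_dim=2,
        bb_lr=0.01,              # learning rate for the Siamese backbone
    )
    model = pt.models.probabilistic.PLVQ(hparams, optimizer=torch.optim.Adam)

One caveat worth noting: `torch.nn.KLDivLoss` expects its first argument in log-space, so the probabilities returned by `forward` may need a `log()` before the loss is applied.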