[FEATURE] Add PLVQ model
This commit is contained in:
parent
fc11d78b38
commit
c87ed5ba8b
@ -31,7 +31,9 @@ be available for use in your Python environment as `prototorch.models`.
|
||||
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
|
||||
- Siamese GLVQ
|
||||
- Cross-Entropy Learning Vector Quantization (CELVQ)
|
||||
- Soft Learning Vector Quantization (SLVQ)
|
||||
- Robust Soft Learning Vector Quantization (RSLVQ)
|
||||
- Probabilistic Learning Vector Quantization (PLVQ)
|
||||
|
||||
### Other
|
||||
|
||||
@ -49,7 +51,6 @@ be available for use in your Python environment as `prototorch.models`.
|
||||
|
||||
- Median-LVQ
|
||||
- Generalized Tangent Learning Vector Quantization (GTLVQ)
|
||||
- Probabilistic Learning Vector Quantization (PLVQ)
|
||||
- Self-Incremental Learning Vector Quantization (SILVQ)
|
||||
|
||||
## Development setup
|
||||
|
@ -2,8 +2,6 @@
|
||||
|
||||
Abstract Models
|
||||
========================================
|
||||
.. autoclass:: prototorch.models.abstract.AbstractPrototypeModel
|
||||
.. automodule:: prototorch.models.abstract
|
||||
:members:
|
||||
|
||||
.. autoclass:: prototorch.models.abstract.PrototypeImageModel
|
||||
:members:
|
||||
:undoc-members:
|
@ -8,7 +8,7 @@ Models
|
||||
|
||||
Unsupervised Methods
|
||||
-----------------------------------------
|
||||
.. autoclass:: prototorch.models.unsupervised.KNN
|
||||
.. autoclass:: prototorch.models.knn.KNN
|
||||
:members:
|
||||
|
||||
.. autoclass:: prototorch.models.unsupervised.NeuralGas
|
||||
@ -80,9 +80,11 @@ Every prototypes is a center of a gaussian distribution of its class, generating
|
||||
.. autoclass:: prototorch.models.probabilistic.RSLVQ
|
||||
:members:
|
||||
|
||||
Missing:
|
||||
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation.
|
||||
And second use divergence as loss function.
|
||||
|
||||
- PLVQ
|
||||
.. autoclass:: prototorch.models.probabilistic.PLVQ
|
||||
:members:
|
||||
|
||||
Classification by Component
|
||||
--------------------------------------------
|
||||
|
@ -60,3 +60,14 @@
|
||||
doi = {10.1162/neco.2009.11-08-908},
|
||||
}
|
||||
|
||||
@InProceedings{villmann2018,
|
||||
author="Villmann, Andrea
|
||||
and Kaden, Marika
|
||||
and Saralajew, Sascha
|
||||
and Villmann, Thomas",
|
||||
title="Probabilistic Learning Vector Quantization with Cross-Entropy for Probabilistic Class Assignments in Classification Learning",
|
||||
booktitle="Artificial Intelligence and Soft Computing",
|
||||
year="2018",
|
||||
publisher="Springer International Publishing",
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@ import argparse
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torchvision.transforms import Lambda
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
@ -24,12 +25,15 @@ if __name__ == "__main__":
|
||||
# Hyperparameters
|
||||
hparams = dict(
|
||||
distribution=[2, 2, 3],
|
||||
lr=0.05,
|
||||
variance=0.1,
|
||||
proto_lr=0.05,
|
||||
lambd=0.1,
|
||||
input_dim=2,
|
||||
latent_dim=2,
|
||||
bb_lr=0.01,
|
||||
)
|
||||
|
||||
# Initialize the model
|
||||
model = pt.models.probabilistic.RSLVQ(
|
||||
model = pt.models.probabilistic.PLVQ(
|
||||
hparams,
|
||||
optimizer=torch.optim.Adam,
|
||||
# prototype_initializer=pt.components.SMI(train_ds),
|
||||
@ -45,7 +49,7 @@ if __name__ == "__main__":
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
||||
vis = pt.models.VisSiameseGLVQ2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
|
@ -4,22 +4,11 @@ from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||
from .cbc import CBC, ImageCBC
|
||||
from .glvq import (
|
||||
GLVQ,
|
||||
GLVQ1,
|
||||
GLVQ21,
|
||||
GMLVQ,
|
||||
GRLVQ,
|
||||
LGMLVQ,
|
||||
LVQMLN,
|
||||
ImageGLVQ,
|
||||
ImageGMLVQ,
|
||||
SiameseGLVQ,
|
||||
SiameseGMLVQ,
|
||||
)
|
||||
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||
from .knn import KNN
|
||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||
from .probabilistic import CELVQ, RSLVQ, SLVQ
|
||||
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||
from .vis import *
|
||||
|
||||
|
@ -2,11 +2,13 @@
|
||||
|
||||
import torch
|
||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
|
||||
from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
|
||||
from prototorch.functions.transforms import gaussian
|
||||
from prototorch.functions.pooling import (stratified_min_pooling,
|
||||
stratified_sum_pooling)
|
||||
from prototorch.functions.transforms import (GaussianPrior,
|
||||
RankScaledGaussianPrior)
|
||||
from prototorch.modules import LambdaLayer, LossLayer
|
||||
|
||||
from .glvq import GLVQ
|
||||
from .glvq import GLVQ, SiameseGMLVQ
|
||||
|
||||
|
||||
class CELVQ(GLVQ):
|
||||
@ -32,13 +34,12 @@ class ProbabilisticLVQ(GLVQ):
|
||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
self.conditional_distribution = gaussian
|
||||
self.conditional_distribution = None
|
||||
self.rejection_confidence = rejection_confidence
|
||||
|
||||
def forward(self, x):
|
||||
distances = self.compute_distances(x)
|
||||
conditional = self.conditional_distribution(distances,
|
||||
self.hparams.variance)
|
||||
conditional = self.conditional_distribution(distances)
|
||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
||||
device=self.device)
|
||||
posterior = conditional * prior
|
||||
@ -66,6 +67,7 @@ class SLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(nllr_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class RSLVQ(ProbabilisticLVQ):
|
||||
@ -73,3 +75,25 @@ class RSLVQ(ProbabilisticLVQ):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.loss = LossLayer(rslvq_loss)
|
||||
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||
|
||||
|
||||
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||
"""Probabilistic Learning Vector Quantization.
|
||||
|
||||
TODO: Use Backbone LVQ instead
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conditional_distribution = RankScaledGaussianPrior(
|
||||
self.hparams.lambd)
|
||||
self.loss = torch.nn.KLDivLoss()
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||
x, y = batch
|
||||
out = self.forward(x)
|
||||
y_dist = torch.nn.functional.one_hot(
|
||||
y.long(), num_classes=self.num_classes).float()
|
||||
batch_loss = self.loss(out, y_dist)
|
||||
loss = batch_loss.sum(dim=0)
|
||||
return loss
|
||||
|
Loading…
Reference in New Issue
Block a user