[FEATURE] Add PLVQ model
This commit is contained in:
parent
fc11d78b38
commit
c87ed5ba8b
@ -31,7 +31,9 @@ be available for use in your Python environment as `prototorch.models`.
|
|||||||
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
|
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
|
||||||
- Siamese GLVQ
|
- Siamese GLVQ
|
||||||
- Cross-Entropy Learning Vector Quantization (CELVQ)
|
- Cross-Entropy Learning Vector Quantization (CELVQ)
|
||||||
|
- Soft Learning Vector Quantization (SLVQ)
|
||||||
- Robust Soft Learning Vector Quantization (RSLVQ)
|
- Robust Soft Learning Vector Quantization (RSLVQ)
|
||||||
|
- Probabilistic Learning Vector Quantization (PLVQ)
|
||||||
|
|
||||||
### Other
|
### Other
|
||||||
|
|
||||||
@ -49,7 +51,6 @@ be available for use in your Python environment as `prototorch.models`.
|
|||||||
|
|
||||||
- Median-LVQ
|
- Median-LVQ
|
||||||
- Generalized Tangent Learning Vector Quantization (GTLVQ)
|
- Generalized Tangent Learning Vector Quantization (GTLVQ)
|
||||||
- Probabilistic Learning Vector Quantization (PLVQ)
|
|
||||||
- Self-Incremental Learning Vector Quantization (SILVQ)
|
- Self-Incremental Learning Vector Quantization (SILVQ)
|
||||||
|
|
||||||
## Development setup
|
## Development setup
|
||||||
|
@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
Abstract Models
|
Abstract Models
|
||||||
========================================
|
========================================
|
||||||
.. autoclass:: prototorch.models.abstract.AbstractPrototypeModel
|
.. automodule:: prototorch.models.abstract
|
||||||
:members:
|
:members:
|
||||||
|
:undoc-members:
|
||||||
.. autoclass:: prototorch.models.abstract.PrototypeImageModel
|
|
||||||
:members:
|
|
@ -8,7 +8,7 @@ Models
|
|||||||
|
|
||||||
Unsupervised Methods
|
Unsupervised Methods
|
||||||
-----------------------------------------
|
-----------------------------------------
|
||||||
.. autoclass:: prototorch.models.unsupervised.KNN
|
.. autoclass:: prototorch.models.knn.KNN
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
.. autoclass:: prototorch.models.unsupervised.NeuralGas
|
.. autoclass:: prototorch.models.unsupervised.NeuralGas
|
||||||
@ -80,9 +80,11 @@ Every prototypes is a center of a gaussian distribution of its class, generating
|
|||||||
.. autoclass:: prototorch.models.probabilistic.RSLVQ
|
.. autoclass:: prototorch.models.probabilistic.RSLVQ
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
Missing:
|
:cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation.
|
||||||
|
And second use divergence as loss function.
|
||||||
|
|
||||||
- PLVQ
|
.. autoclass:: prototorch.models.probabilistic.PLVQ
|
||||||
|
:members:
|
||||||
|
|
||||||
Classification by Component
|
Classification by Component
|
||||||
--------------------------------------------
|
--------------------------------------------
|
||||||
|
@ -60,3 +60,14 @@
|
|||||||
doi = {10.1162/neco.2009.11-08-908},
|
doi = {10.1162/neco.2009.11-08-908},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@InProceedings{villmann2018,
|
||||||
|
author="Villmann, Andrea
|
||||||
|
and Kaden, Marika
|
||||||
|
and Saralajew, Sascha
|
||||||
|
and Villmann, Thomas",
|
||||||
|
title="Probabilistic Learning Vector Quantization with Cross-Entropy for Probabilistic Class Assignments in Classification Learning",
|
||||||
|
booktitle="Artificial Intelligence and Soft Computing",
|
||||||
|
year="2018",
|
||||||
|
publisher="Springer International Publishing",
|
||||||
|
}
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import argparse
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from torchvision.transforms import Lambda
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Command-line arguments
|
# Command-line arguments
|
||||||
@ -24,12 +25,15 @@ if __name__ == "__main__":
|
|||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
hparams = dict(
|
hparams = dict(
|
||||||
distribution=[2, 2, 3],
|
distribution=[2, 2, 3],
|
||||||
lr=0.05,
|
proto_lr=0.05,
|
||||||
variance=0.1,
|
lambd=0.1,
|
||||||
|
input_dim=2,
|
||||||
|
latent_dim=2,
|
||||||
|
bb_lr=0.01,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the model
|
# Initialize the model
|
||||||
model = pt.models.probabilistic.RSLVQ(
|
model = pt.models.probabilistic.PLVQ(
|
||||||
hparams,
|
hparams,
|
||||||
optimizer=torch.optim.Adam,
|
optimizer=torch.optim.Adam,
|
||||||
# prototype_initializer=pt.components.SMI(train_ds),
|
# prototype_initializer=pt.components.SMI(train_ds),
|
||||||
@ -45,7 +49,7 @@ if __name__ == "__main__":
|
|||||||
print(model)
|
print(model)
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=train_ds)
|
vis = pt.models.VisSiameseGLVQ2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
|
@ -4,22 +4,11 @@ from importlib.metadata import PackageNotFoundError, version
|
|||||||
|
|
||||||
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
from .callbacks import PrototypeConvergence, PruneLoserPrototypes
|
||||||
from .cbc import CBC, ImageCBC
|
from .cbc import CBC, ImageCBC
|
||||||
from .glvq import (
|
from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN,
|
||||||
GLVQ,
|
ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ)
|
||||||
GLVQ1,
|
|
||||||
GLVQ21,
|
|
||||||
GMLVQ,
|
|
||||||
GRLVQ,
|
|
||||||
LGMLVQ,
|
|
||||||
LVQMLN,
|
|
||||||
ImageGLVQ,
|
|
||||||
ImageGMLVQ,
|
|
||||||
SiameseGLVQ,
|
|
||||||
SiameseGMLVQ,
|
|
||||||
)
|
|
||||||
from .knn import KNN
|
from .knn import KNN
|
||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
|
from .lvq import LVQ1, LVQ21, MedianLVQ
|
||||||
from .probabilistic import CELVQ, RSLVQ, SLVQ
|
from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ
|
||||||
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
|
||||||
from .vis import *
|
from .vis import *
|
||||||
|
|
||||||
|
@ -2,11 +2,13 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from prototorch.functions.losses import nllr_loss, rslvq_loss
|
from prototorch.functions.losses import nllr_loss, rslvq_loss
|
||||||
from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling
|
from prototorch.functions.pooling import (stratified_min_pooling,
|
||||||
from prototorch.functions.transforms import gaussian
|
stratified_sum_pooling)
|
||||||
|
from prototorch.functions.transforms import (GaussianPrior,
|
||||||
|
RankScaledGaussianPrior)
|
||||||
from prototorch.modules import LambdaLayer, LossLayer
|
from prototorch.modules import LambdaLayer, LossLayer
|
||||||
|
|
||||||
from .glvq import GLVQ
|
from .glvq import GLVQ, SiameseGMLVQ
|
||||||
|
|
||||||
|
|
||||||
class CELVQ(GLVQ):
|
class CELVQ(GLVQ):
|
||||||
@ -32,13 +34,12 @@ class ProbabilisticLVQ(GLVQ):
|
|||||||
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
def __init__(self, hparams, rejection_confidence=0.0, **kwargs):
|
||||||
super().__init__(hparams, **kwargs)
|
super().__init__(hparams, **kwargs)
|
||||||
|
|
||||||
self.conditional_distribution = gaussian
|
self.conditional_distribution = None
|
||||||
self.rejection_confidence = rejection_confidence
|
self.rejection_confidence = rejection_confidence
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
distances = self.compute_distances(x)
|
distances = self.compute_distances(x)
|
||||||
conditional = self.conditional_distribution(distances,
|
conditional = self.conditional_distribution(distances)
|
||||||
self.hparams.variance)
|
|
||||||
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
posterior = conditional * prior
|
posterior = conditional * prior
|
||||||
@ -66,6 +67,7 @@ class SLVQ(ProbabilisticLVQ):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.loss = LossLayer(nllr_loss)
|
self.loss = LossLayer(nllr_loss)
|
||||||
|
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||||
|
|
||||||
|
|
||||||
class RSLVQ(ProbabilisticLVQ):
|
class RSLVQ(ProbabilisticLVQ):
|
||||||
@ -73,3 +75,25 @@ class RSLVQ(ProbabilisticLVQ):
|
|||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.loss = LossLayer(rslvq_loss)
|
self.loss = LossLayer(rslvq_loss)
|
||||||
|
self.conditional_distribution = GaussianPrior(self.hparams.variance)
|
||||||
|
|
||||||
|
|
||||||
|
class PLVQ(ProbabilisticLVQ, SiameseGMLVQ):
|
||||||
|
"""Probabilistic Learning Vector Quantization.
|
||||||
|
|
||||||
|
TODO: Use Backbone LVQ instead
|
||||||
|
"""
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.conditional_distribution = RankScaledGaussianPrior(
|
||||||
|
self.hparams.lambd)
|
||||||
|
self.loss = torch.nn.KLDivLoss()
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
x, y = batch
|
||||||
|
out = self.forward(x)
|
||||||
|
y_dist = torch.nn.functional.one_hot(
|
||||||
|
y.long(), num_classes=self.num_classes).float()
|
||||||
|
batch_loss = self.loss(out, y_dist)
|
||||||
|
loss = batch_loss.sum(dim=0)
|
||||||
|
return loss
|
||||||
|
Loading…
Reference in New Issue
Block a user