[FEATURE] Add PLVQ model
This commit is contained in:
		
				
					committed by
					
						 Alexander Engelsberger
						Alexander Engelsberger
					
				
			
			
				
	
			
			
			
						parent
						
							fc11d78b38
						
					
				
				
					commit
					c87ed5ba8b
				
			| @@ -31,7 +31,9 @@ be available for use in your Python environment as `prototorch.models`. | |||||||
| - Learning Vector Quantization Multi-Layer Network (LVQMLN) | - Learning Vector Quantization Multi-Layer Network (LVQMLN) | ||||||
| - Siamese GLVQ | - Siamese GLVQ | ||||||
| - Cross-Entropy Learning Vector Quantization (CELVQ) | - Cross-Entropy Learning Vector Quantization (CELVQ) | ||||||
|  | - Soft Learning Vector Quantization (SLVQ) | ||||||
| - Robust Soft Learning Vector Quantization (RSLVQ) | - Robust Soft Learning Vector Quantization (RSLVQ) | ||||||
|  | - Probabilistic Learning Vector Quantization (PLVQ) | ||||||
|  |  | ||||||
| ### Other | ### Other | ||||||
|  |  | ||||||
| @@ -49,7 +51,6 @@ be available for use in your Python environment as `prototorch.models`. | |||||||
|  |  | ||||||
| - Median-LVQ | - Median-LVQ | ||||||
| - Generalized Tangent Learning Vector Quantization (GTLVQ) | - Generalized Tangent Learning Vector Quantization (GTLVQ) | ||||||
| - Probabilistic Learning Vector Quantization (PLVQ) |  | ||||||
| - Self-Incremental Learning Vector Quantization (SILVQ) | - Self-Incremental Learning Vector Quantization (SILVQ) | ||||||
|  |  | ||||||
| ## Development setup | ## Development setup | ||||||
|   | |||||||
| @@ -2,8 +2,6 @@ | |||||||
|  |  | ||||||
| Abstract Models | Abstract Models | ||||||
| ======================================== | ======================================== | ||||||
| .. autoclass:: prototorch.models.abstract.AbstractPrototypeModel | .. automodule:: prototorch.models.abstract | ||||||
|    :members: |  | ||||||
|  |  | ||||||
| .. autoclass:: prototorch.models.abstract.PrototypeImageModel |  | ||||||
|    :members: |    :members: | ||||||
|  |    :undoc-members: | ||||||
| @@ -8,7 +8,7 @@ Models | |||||||
|  |  | ||||||
| Unsupervised Methods | Unsupervised Methods | ||||||
| ----------------------------------------- | ----------------------------------------- | ||||||
| .. autoclass:: prototorch.models.unsupervised.KNN | .. autoclass:: prototorch.models.knn.KNN | ||||||
|    :members: |    :members: | ||||||
|  |  | ||||||
| .. autoclass:: prototorch.models.unsupervised.NeuralGas | .. autoclass:: prototorch.models.unsupervised.NeuralGas | ||||||
| @@ -80,9 +80,11 @@ Every prototypes is a center of a gaussian distribution of its class, generating | |||||||
| .. autoclass:: prototorch.models.probabilistic.RSLVQ | .. autoclass:: prototorch.models.probabilistic.RSLVQ | ||||||
|    :members: |    :members: | ||||||
|  |  | ||||||
| Missing: | :cite:t:`villmann2018` proposed two changes to RSLVQ: First incooperate the winning rank into the prior probability calculation. | ||||||
|  | And second use divergence as loss function. | ||||||
|  |  | ||||||
| - PLVQ | .. autoclass:: prototorch.models.probabilistic.PLVQ | ||||||
|  |    :members: | ||||||
|  |  | ||||||
| Classification by Component | Classification by Component | ||||||
| -------------------------------------------- | -------------------------------------------- | ||||||
|   | |||||||
| @@ -60,3 +60,14 @@ | |||||||
|     doi = {10.1162/neco.2009.11-08-908}, |     doi = {10.1162/neco.2009.11-08-908}, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @InProceedings{villmann2018, | ||||||
|  |     author="Villmann, Andrea | ||||||
|  |     and Kaden, Marika | ||||||
|  |     and Saralajew, Sascha | ||||||
|  |     and Villmann, Thomas", | ||||||
|  |     title="Probabilistic Learning Vector Quantization with Cross-Entropy for Probabilistic Class Assignments in Classification Learning", | ||||||
|  |     booktitle="Artificial Intelligence and Soft Computing", | ||||||
|  |     year="2018", | ||||||
|  |     publisher="Springer International Publishing", | ||||||
|  | } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -5,6 +5,7 @@ import argparse | |||||||
| import prototorch as pt | import prototorch as pt | ||||||
| import pytorch_lightning as pl | import pytorch_lightning as pl | ||||||
| import torch | import torch | ||||||
|  | from torchvision.transforms import Lambda | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     # Command-line arguments |     # Command-line arguments | ||||||
| @@ -24,12 +25,15 @@ if __name__ == "__main__": | |||||||
|     # Hyperparameters |     # Hyperparameters | ||||||
|     hparams = dict( |     hparams = dict( | ||||||
|         distribution=[2, 2, 3], |         distribution=[2, 2, 3], | ||||||
|         lr=0.05, |         proto_lr=0.05, | ||||||
|         variance=0.1, |         lambd=0.1, | ||||||
|  |         input_dim=2, | ||||||
|  |         latent_dim=2, | ||||||
|  |         bb_lr=0.01, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     # Initialize the model |     # Initialize the model | ||||||
|     model = pt.models.probabilistic.RSLVQ( |     model = pt.models.probabilistic.PLVQ( | ||||||
|         hparams, |         hparams, | ||||||
|         optimizer=torch.optim.Adam, |         optimizer=torch.optim.Adam, | ||||||
|         # prototype_initializer=pt.components.SMI(train_ds), |         # prototype_initializer=pt.components.SMI(train_ds), | ||||||
| @@ -45,7 +49,7 @@ if __name__ == "__main__": | |||||||
|     print(model) |     print(model) | ||||||
|  |  | ||||||
|     # Callbacks |     # Callbacks | ||||||
|     vis = pt.models.VisGLVQ2D(data=train_ds) |     vis = pt.models.VisSiameseGLVQ2D(data=train_ds) | ||||||
|  |  | ||||||
|     # Setup trainer |     # Setup trainer | ||||||
|     trainer = pl.Trainer.from_argparse_args( |     trainer = pl.Trainer.from_argparse_args( | ||||||
|   | |||||||
| @@ -4,22 +4,11 @@ from importlib.metadata import PackageNotFoundError, version | |||||||
|  |  | ||||||
| from .callbacks import PrototypeConvergence, PruneLoserPrototypes | from .callbacks import PrototypeConvergence, PruneLoserPrototypes | ||||||
| from .cbc import CBC, ImageCBC | from .cbc import CBC, ImageCBC | ||||||
| from .glvq import ( | from .glvq import (GLVQ, GLVQ1, GLVQ21, GMLVQ, GRLVQ, LGMLVQ, LVQMLN, | ||||||
|     GLVQ, |                    ImageGLVQ, ImageGMLVQ, SiameseGLVQ, SiameseGMLVQ) | ||||||
|     GLVQ1, |  | ||||||
|     GLVQ21, |  | ||||||
|     GMLVQ, |  | ||||||
|     GRLVQ, |  | ||||||
|     LGMLVQ, |  | ||||||
|     LVQMLN, |  | ||||||
|     ImageGLVQ, |  | ||||||
|     ImageGMLVQ, |  | ||||||
|     SiameseGLVQ, |  | ||||||
|     SiameseGMLVQ, |  | ||||||
| ) |  | ||||||
| from .knn import KNN | from .knn import KNN | ||||||
| from .lvq import LVQ1, LVQ21, MedianLVQ | from .lvq import LVQ1, LVQ21, MedianLVQ | ||||||
| from .probabilistic import CELVQ, RSLVQ, SLVQ | from .probabilistic import CELVQ, PLVQ, RSLVQ, SLVQ | ||||||
| from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas | from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas | ||||||
| from .vis import * | from .vis import * | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2,11 +2,13 @@ | |||||||
|  |  | ||||||
| import torch | import torch | ||||||
| from prototorch.functions.losses import nllr_loss, rslvq_loss | from prototorch.functions.losses import nllr_loss, rslvq_loss | ||||||
| from prototorch.functions.pooling import stratified_min_pooling, stratified_sum_pooling | from prototorch.functions.pooling import (stratified_min_pooling, | ||||||
| from prototorch.functions.transforms import gaussian |                                           stratified_sum_pooling) | ||||||
|  | from prototorch.functions.transforms import (GaussianPrior, | ||||||
|  |                                              RankScaledGaussianPrior) | ||||||
| from prototorch.modules import LambdaLayer, LossLayer | from prototorch.modules import LambdaLayer, LossLayer | ||||||
|  |  | ||||||
| from .glvq import GLVQ | from .glvq import GLVQ, SiameseGMLVQ | ||||||
|  |  | ||||||
|  |  | ||||||
| class CELVQ(GLVQ): | class CELVQ(GLVQ): | ||||||
| @@ -32,13 +34,12 @@ class ProbabilisticLVQ(GLVQ): | |||||||
|     def __init__(self, hparams, rejection_confidence=0.0, **kwargs): |     def __init__(self, hparams, rejection_confidence=0.0, **kwargs): | ||||||
|         super().__init__(hparams, **kwargs) |         super().__init__(hparams, **kwargs) | ||||||
|  |  | ||||||
|         self.conditional_distribution = gaussian |         self.conditional_distribution = None | ||||||
|         self.rejection_confidence = rejection_confidence |         self.rejection_confidence = rejection_confidence | ||||||
|  |  | ||||||
|     def forward(self, x): |     def forward(self, x): | ||||||
|         distances = self.compute_distances(x) |         distances = self.compute_distances(x) | ||||||
|         conditional = self.conditional_distribution(distances, |         conditional = self.conditional_distribution(distances) | ||||||
|                                                     self.hparams.variance) |  | ||||||
|         prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes, |         prior = (1. / self.num_prototypes) * torch.ones(self.num_prototypes, | ||||||
|                                                         device=self.device) |                                                         device=self.device) | ||||||
|         posterior = conditional * prior |         posterior = conditional * prior | ||||||
| @@ -66,6 +67,7 @@ class SLVQ(ProbabilisticLVQ): | |||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.loss = LossLayer(nllr_loss) |         self.loss = LossLayer(nllr_loss) | ||||||
|  |         self.conditional_distribution = GaussianPrior(self.hparams.variance) | ||||||
|  |  | ||||||
|  |  | ||||||
| class RSLVQ(ProbabilisticLVQ): | class RSLVQ(ProbabilisticLVQ): | ||||||
| @@ -73,3 +75,25 @@ class RSLVQ(ProbabilisticLVQ): | |||||||
|     def __init__(self, *args, **kwargs): |     def __init__(self, *args, **kwargs): | ||||||
|         super().__init__(*args, **kwargs) |         super().__init__(*args, **kwargs) | ||||||
|         self.loss = LossLayer(rslvq_loss) |         self.loss = LossLayer(rslvq_loss) | ||||||
|  |         self.conditional_distribution = GaussianPrior(self.hparams.variance) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class PLVQ(ProbabilisticLVQ, SiameseGMLVQ): | ||||||
|  |     """Probabilistic Learning Vector Quantization. | ||||||
|  |  | ||||||
|  |     TODO: Use Backbone LVQ instead | ||||||
|  |     """ | ||||||
|  |     def __init__(self, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.conditional_distribution = RankScaledGaussianPrior( | ||||||
|  |             self.hparams.lambd) | ||||||
|  |         self.loss = torch.nn.KLDivLoss() | ||||||
|  |  | ||||||
|  |     def training_step(self, batch, batch_idx, optimizer_idx=None): | ||||||
|  |         x, y = batch | ||||||
|  |         out = self.forward(x) | ||||||
|  |         y_dist = torch.nn.functional.one_hot( | ||||||
|  |             y.long(), num_classes=self.num_classes).float() | ||||||
|  |         batch_loss = self.loss(out, y_dist) | ||||||
|  |         loss = batch_loss.sum(dim=0) | ||||||
|  |         return loss | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user