[WIP] Add SOM
This commit is contained in:
		@@ -26,8 +26,8 @@ be available for use in your Python environment as `prototorch.models`.
 | 
				
			|||||||
- Generalized Learning Vector Quantization (GLVQ)
 | 
					- Generalized Learning Vector Quantization (GLVQ)
 | 
				
			||||||
- Generalized Relevance Learning Vector Quantization (GRLVQ)
 | 
					- Generalized Relevance Learning Vector Quantization (GRLVQ)
 | 
				
			||||||
- Generalized Matrix Learning Vector Quantization (GMLVQ)
 | 
					- Generalized Matrix Learning Vector Quantization (GMLVQ)
 | 
				
			||||||
- Localized and Generalized Matrix Learning Vector Quantization (LGMLVQ)
 | 
					 | 
				
			||||||
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
 | 
					- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
 | 
				
			||||||
 | 
					- Localized and Generalized Matrix Learning Vector Quantization (LGMLVQ)
 | 
				
			||||||
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
 | 
					- Learning Vector Quantization Multi-Layer Network (LVQMLN)
 | 
				
			||||||
- Siamese GLVQ
 | 
					- Siamese GLVQ
 | 
				
			||||||
- Cross-Entropy Learning Vector Quantization (CELVQ)
 | 
					- Cross-Entropy Learning Vector Quantization (CELVQ)
 | 
				
			||||||
@@ -43,6 +43,7 @@ be available for use in your Python environment as `prototorch.models`.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
- Classification-By-Components Network (CBC)
 | 
					- Classification-By-Components Network (CBC)
 | 
				
			||||||
- Learning Vector Quantization 2.1 (LVQ2.1)
 | 
					- Learning Vector Quantization 2.1 (LVQ2.1)
 | 
				
			||||||
 | 
					- Self-Organizing-Map (SOM)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Planned models
 | 
					## Planned models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										112
									
								
								examples/ksom_colors.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								examples/ksom_colors.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,112 @@
 | 
				
			|||||||
 | 
					"""Kohonen Self Organizing Map."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import argparse
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import prototorch as pt
 | 
				
			||||||
 | 
					import pytorch_lightning as pl
 | 
				
			||||||
 | 
					import torch
 | 
				
			||||||
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def hex_to_rgb(hex_values):
 | 
				
			||||||
 | 
					    for v in hex_values:
 | 
				
			||||||
 | 
					        v = v.lstrip('#')
 | 
				
			||||||
 | 
					        lv = len(v)
 | 
				
			||||||
 | 
					        c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
 | 
				
			||||||
 | 
					        yield c
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def rgb_to_hex(rgb_values):
 | 
				
			||||||
 | 
					    for v in rgb_values:
 | 
				
			||||||
 | 
					        c = "%02x%02x%02x" % tuple(v)
 | 
				
			||||||
 | 
					        yield c
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Vis2DColorSOM(pl.Callback):
 | 
				
			||||||
 | 
					    def __init__(self, data, title="ColorSOMe", pause_time=0.1):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					        self.title = title
 | 
				
			||||||
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
 | 
					        self.data = data
 | 
				
			||||||
 | 
					        self.pause_time = pause_time
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        ax = self.fig.gca()
 | 
				
			||||||
 | 
					        ax.cla()
 | 
				
			||||||
 | 
					        ax.set_title(self.title)
 | 
				
			||||||
 | 
					        h, w = pl_module._grid.shape[:2]
 | 
				
			||||||
 | 
					        protos = pl_module.prototypes.view(h, w, 3)
 | 
				
			||||||
 | 
					        ax.imshow(protos)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Overlay color names
 | 
				
			||||||
 | 
					        d = pl_module.compute_distances(self.data)
 | 
				
			||||||
 | 
					        wp = pl_module.predict_from_distances(d)
 | 
				
			||||||
 | 
					        for i, iloc in enumerate(wp):
 | 
				
			||||||
 | 
					            plt.text(iloc[1],
 | 
				
			||||||
 | 
					                     iloc[0],
 | 
				
			||||||
 | 
					                     cnames[i],
 | 
				
			||||||
 | 
					                     ha="center",
 | 
				
			||||||
 | 
					                     va="center",
 | 
				
			||||||
 | 
					                     bbox=dict(facecolor="white", alpha=0.5, lw=0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        plt.pause(self.pause_time)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if __name__ == "__main__":
 | 
				
			||||||
 | 
					    # Command-line arguments
 | 
				
			||||||
 | 
					    parser = argparse.ArgumentParser()
 | 
				
			||||||
 | 
					    parser = pl.Trainer.add_argparse_args(parser)
 | 
				
			||||||
 | 
					    args = parser.parse_args()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Reproducibility
 | 
				
			||||||
 | 
					    pl.utilities.seed.seed_everything(seed=42)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Prepare the data
 | 
				
			||||||
 | 
					    hex_colors = [
 | 
				
			||||||
 | 
					        "#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
 | 
				
			||||||
 | 
					        "#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
 | 
				
			||||||
 | 
					        "#545454", "#7f7f7f", "#a8a8a8"
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    cnames = [
 | 
				
			||||||
 | 
					        "black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
 | 
				
			||||||
 | 
					        "red", "cyan", "violet", "yellow", "white", "darkgrey", "mediumgrey",
 | 
				
			||||||
 | 
					        "lightgrey"
 | 
				
			||||||
 | 
					    ]
 | 
				
			||||||
 | 
					    colors = list(hex_to_rgb(hex_colors))
 | 
				
			||||||
 | 
					    data = torch.Tensor(colors) / 255.0
 | 
				
			||||||
 | 
					    train_ds = torch.utils.data.TensorDataset(data)
 | 
				
			||||||
 | 
					    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Hyperparameters
 | 
				
			||||||
 | 
					    hparams = dict(
 | 
				
			||||||
 | 
					        shape=(18, 32),
 | 
				
			||||||
 | 
					        alpha=1.0,
 | 
				
			||||||
 | 
					        sigma=3,
 | 
				
			||||||
 | 
					        lr=0.1,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Initialize the model
 | 
				
			||||||
 | 
					    model = pt.models.KohonenSOM(
 | 
				
			||||||
 | 
					        hparams,
 | 
				
			||||||
 | 
					        prototype_initializer=pt.components.Random(3),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Compute intermediate input and output sizes
 | 
				
			||||||
 | 
					    model.example_input_array = torch.zeros(4, 3)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Model summary
 | 
				
			||||||
 | 
					    print(model)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Callbacks
 | 
				
			||||||
 | 
					    vis = Vis2DColorSOM(data=data)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Setup trainer
 | 
				
			||||||
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
 | 
					        args,
 | 
				
			||||||
 | 
					        max_epochs=300,
 | 
				
			||||||
 | 
					        callbacks=[vis],
 | 
				
			||||||
 | 
					        weights_summary="full",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Training loop
 | 
				
			||||||
 | 
					    trainer.fit(model, train_loader)
 | 
				
			||||||
@@ -20,7 +20,7 @@ from .glvq import (
 | 
				
			|||||||
from .knn import KNN
 | 
					from .knn import KNN
 | 
				
			||||||
from .lvq import LVQ1, LVQ21, MedianLVQ
 | 
					from .lvq import LVQ1, LVQ21, MedianLVQ
 | 
				
			||||||
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
 | 
					from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
 | 
				
			||||||
from .unsupervised import GrowingNeuralGas, NeuralGas
 | 
					from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
 | 
				
			||||||
from .vis import *
 | 
					from .vis import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__version__ = "0.1.7"
 | 
					__version__ = "0.1.7"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,25 +1,76 @@
 | 
				
			|||||||
"""Unsupervised prototype learning algorithms."""
 | 
					"""Unsupervised prototype learning algorithms."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import logging
 | 
					 | 
				
			||||||
import warnings
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import prototorch as pt
 | 
					 | 
				
			||||||
import pytorch_lightning as pl
 | 
					 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
import torchmetrics
 | 
					from prototorch.functions.competitions import wtac
 | 
				
			||||||
from prototorch.components import Components, LabeledComponents
 | 
					from prototorch.functions.distances import squared_euclidean_distance
 | 
				
			||||||
from prototorch.components.initializers import ZerosInitializer
 | 
					from prototorch.functions.helper import get_flat
 | 
				
			||||||
from prototorch.functions.competitions import knnc
 | 
					 | 
				
			||||||
from prototorch.functions.distances import euclidean_distance
 | 
					 | 
				
			||||||
from prototorch.modules import LambdaLayer
 | 
					from prototorch.modules import LambdaLayer
 | 
				
			||||||
from prototorch.modules.losses import NeuralGasEnergy
 | 
					from prototorch.modules.losses import NeuralGasEnergy
 | 
				
			||||||
from pytorch_lightning.callbacks import Callback
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .abstract import UnsupervisedPrototypeModel
 | 
					from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
 | 
				
			||||||
from .callbacks import GNGCallback
 | 
					from .callbacks import GNGCallback
 | 
				
			||||||
from .extras import ConnectionTopology
 | 
					from .extras import ConnectionTopology
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
 | 
				
			||||||
 | 
					    """Kohonen Self-Organizing-Map.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    TODO Allow non-2D grids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
 | 
					        h, w = hparams.get("shape")
 | 
				
			||||||
 | 
					        # Ignore `num_prototypes`
 | 
				
			||||||
 | 
					        hparams["num_prototypes"] = h * w
 | 
				
			||||||
 | 
					        distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
 | 
				
			||||||
 | 
					        super().__init__(hparams, distance_fn=distance_fn, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Hyperparameters
 | 
				
			||||||
 | 
					        self.save_hyperparameters(hparams)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Default hparams
 | 
				
			||||||
 | 
					        self.hparams.setdefault("alpha", 0.3)
 | 
				
			||||||
 | 
					        self.hparams.setdefault("sigma", max(h, w) / 2.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Additional parameters
 | 
				
			||||||
 | 
					        x, y = torch.arange(h), torch.arange(w)
 | 
				
			||||||
 | 
					        grid = torch.stack(torch.meshgrid(x, y), dim=-1)
 | 
				
			||||||
 | 
					        self.register_buffer("_grid", grid)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def predict_from_distances(self, distances):
 | 
				
			||||||
 | 
					        grid = self._grid.view(-1, 2)
 | 
				
			||||||
 | 
					        wp = wtac(distances, grid)
 | 
				
			||||||
 | 
					        return wp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, train_batch, batch_idx):
 | 
				
			||||||
 | 
					        # x = train_batch
 | 
				
			||||||
 | 
					        # TODO Check if the batch has labels
 | 
				
			||||||
 | 
					        x = train_batch[0]
 | 
				
			||||||
 | 
					        d = self.compute_distances(x)
 | 
				
			||||||
 | 
					        wp = self.predict_from_distances(d)
 | 
				
			||||||
 | 
					        grid = self._grid.view(-1, 2)
 | 
				
			||||||
 | 
					        gd = squared_euclidean_distance(wp, grid)
 | 
				
			||||||
 | 
					        nh = torch.exp(-gd / self.hparams.sigma**2)
 | 
				
			||||||
 | 
					        protos = self.proto_layer.components
 | 
				
			||||||
 | 
					        diff = x.unsqueeze(dim=1) - protos
 | 
				
			||||||
 | 
					        delta = self.hparams.lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
 | 
				
			||||||
 | 
					        updated_protos = protos + delta.sum(dim=0)
 | 
				
			||||||
 | 
					        self.proto_layer.load_state_dict({"_components": updated_protos},
 | 
				
			||||||
 | 
					                                         strict=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def extra_repr(self):
 | 
				
			||||||
 | 
					        return f"(grid): (shape: {tuple(self._grid.shape)})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class HeskesSOM(UnsupervisedPrototypeModel):
 | 
				
			||||||
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def training_step(self, train_batch, batch_idx):
 | 
				
			||||||
 | 
					        # TODO Implement me!
 | 
				
			||||||
 | 
					        raise NotImplementedError()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class NeuralGas(UnsupervisedPrototypeModel):
 | 
					class NeuralGas(UnsupervisedPrototypeModel):
 | 
				
			||||||
    def __init__(self, hparams, **kwargs):
 | 
					    def __init__(self, hparams, **kwargs):
 | 
				
			||||||
        super().__init__(hparams, **kwargs)
 | 
					        super().__init__(hparams, **kwargs)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user