[WIP] Add SOM

This commit is contained in:
Jensun Ravichandran 2021-06-07 18:44:15 +02:00
parent b031382072
commit c7b5c88776
4 changed files with 178 additions and 14 deletions

View File

@ -26,8 +26,8 @@ be available for use in your Python environment as `prototorch.models`.
- Generalized Learning Vector Quantization (GLVQ) - Generalized Learning Vector Quantization (GLVQ)
- Generalized Relevance Learning Vector Quantization (GRLVQ) - Generalized Relevance Learning Vector Quantization (GRLVQ)
- Generalized Matrix Learning Vector Quantization (GMLVQ) - Generalized Matrix Learning Vector Quantization (GMLVQ)
- Localized and Generalized Matrix Learning Vector Quantization (LGMLVQ)
- Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ) - Limited-Rank Matrix Learning Vector Quantization (LiRaMLVQ)
- Localized and Generalized Matrix Learning Vector Quantization (LGMLVQ)
- Learning Vector Quantization Multi-Layer Network (LVQMLN) - Learning Vector Quantization Multi-Layer Network (LVQMLN)
- Siamese GLVQ - Siamese GLVQ
- Cross-Entropy Learning Vector Quantization (CELVQ) - Cross-Entropy Learning Vector Quantization (CELVQ)
@ -43,6 +43,7 @@ be available for use in your Python environment as `prototorch.models`.
- Classification-By-Components Network (CBC) - Classification-By-Components Network (CBC)
- Learning Vector Quantization 2.1 (LVQ2.1) - Learning Vector Quantization 2.1 (LVQ2.1)
- Self-Organizing-Map (SOM)
## Planned models ## Planned models

112
examples/ksom_colors.py Normal file
View File

@ -0,0 +1,112 @@
"""Kohonen Self Organizing Map."""
import argparse
import prototorch as pt
import pytorch_lightning as pl
import torch
from matplotlib import pyplot as plt
def hex_to_rgb(hex_values):
for v in hex_values:
v = v.lstrip('#')
lv = len(v)
c = [int(v[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)]
yield c
def rgb_to_hex(rgb_values):
for v in rgb_values:
c = "%02x%02x%02x" % tuple(v)
yield c
class Vis2DColorSOM(pl.Callback):
def __init__(self, data, title="ColorSOMe", pause_time=0.1):
super().__init__()
self.title = title
self.fig = plt.figure(self.title)
self.data = data
self.pause_time = pause_time
def on_epoch_end(self, trainer, pl_module):
ax = self.fig.gca()
ax.cla()
ax.set_title(self.title)
h, w = pl_module._grid.shape[:2]
protos = pl_module.prototypes.view(h, w, 3)
ax.imshow(protos)
# Overlay color names
d = pl_module.compute_distances(self.data)
wp = pl_module.predict_from_distances(d)
for i, iloc in enumerate(wp):
plt.text(iloc[1],
iloc[0],
cnames[i],
ha="center",
va="center",
bbox=dict(facecolor="white", alpha=0.5, lw=0))
plt.pause(self.pause_time)
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Reproducibility
pl.utilities.seed.seed_everything(seed=42)
# Prepare the data
hex_colors = [
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
"#545454", "#7f7f7f", "#a8a8a8"
]
cnames = [
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
"red", "cyan", "violet", "yellow", "white", "darkgrey", "mediumgrey",
"lightgrey"
]
colors = list(hex_to_rgb(hex_colors))
data = torch.Tensor(colors) / 255.0
train_ds = torch.utils.data.TensorDataset(data)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=8)
# Hyperparameters
hparams = dict(
shape=(18, 32),
alpha=1.0,
sigma=3,
lr=0.1,
)
# Initialize the model
model = pt.models.KohonenSOM(
hparams,
prototype_initializer=pt.components.Random(3),
)
# Compute intermediate input and output sizes
model.example_input_array = torch.zeros(4, 3)
# Model summary
print(model)
# Callbacks
vis = Vis2DColorSOM(data=data)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=300,
callbacks=[vis],
weights_summary="full",
)
# Training loop
trainer.fit(model, train_loader)

View File

@ -20,7 +20,7 @@ from .glvq import (
from .knn import KNN from .knn import KNN
from .lvq import LVQ1, LVQ21, MedianLVQ from .lvq import LVQ1, LVQ21, MedianLVQ
from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ from .probabilistic import CELVQ, RSLVQ, LikelihoodRatioLVQ
from .unsupervised import GrowingNeuralGas, NeuralGas from .unsupervised import GrowingNeuralGas, HeskesSOM, KohonenSOM, NeuralGas
from .vis import * from .vis import *
__version__ = "0.1.7" __version__ = "0.1.7"

View File

@ -1,25 +1,76 @@
"""Unsupervised prototype learning algorithms.""" """Unsupervised prototype learning algorithms."""
import logging
import warnings
import prototorch as pt
import pytorch_lightning as pl
import torch import torch
import torchmetrics from prototorch.functions.competitions import wtac
from prototorch.components import Components, LabeledComponents from prototorch.functions.distances import squared_euclidean_distance
from prototorch.components.initializers import ZerosInitializer from prototorch.functions.helper import get_flat
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from prototorch.modules import LambdaLayer from prototorch.modules import LambdaLayer
from prototorch.modules.losses import NeuralGasEnergy from prototorch.modules.losses import NeuralGasEnergy
from pytorch_lightning.callbacks import Callback
from .abstract import UnsupervisedPrototypeModel from .abstract import NonGradientMixin, UnsupervisedPrototypeModel
from .callbacks import GNGCallback from .callbacks import GNGCallback
from .extras import ConnectionTopology from .extras import ConnectionTopology
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
"""Kohonen Self-Organizing-Map.
TODO Allow non-2D grids
"""
def __init__(self, hparams, **kwargs):
h, w = hparams.get("shape")
# Ignore `num_prototypes`
hparams["num_prototypes"] = h * w
distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
super().__init__(hparams, distance_fn=distance_fn, **kwargs)
# Hyperparameters
self.save_hyperparameters(hparams)
# Default hparams
self.hparams.setdefault("alpha", 0.3)
self.hparams.setdefault("sigma", max(h, w) / 2.0)
# Additional parameters
x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
self.register_buffer("_grid", grid)
def predict_from_distances(self, distances):
grid = self._grid.view(-1, 2)
wp = wtac(distances, grid)
return wp
def training_step(self, train_batch, batch_idx):
# x = train_batch
# TODO Check if the batch has labels
x = train_batch[0]
d = self.compute_distances(x)
wp = self.predict_from_distances(d)
grid = self._grid.view(-1, 2)
gd = squared_euclidean_distance(wp, grid)
nh = torch.exp(-gd / self.hparams.sigma**2)
protos = self.proto_layer.components
diff = x.unsqueeze(dim=1) - protos
delta = self.hparams.lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False)
def extra_repr(self):
return f"(grid): (shape: {tuple(self._grid.shape)})"
class HeskesSOM(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
def training_step(self, train_batch, batch_idx):
# TODO Implement me!
raise NotImplementedError()
class NeuralGas(UnsupervisedPrototypeModel): class NeuralGas(UnsupervisedPrototypeModel):
def __init__(self, hparams, **kwargs): def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs) super().__init__(hparams, **kwargs)