[FEATURE] Add Growing Neural Gas

This commit is contained in:
Alexander Engelsberger 2021-06-01 17:19:43 +02:00
parent 1636c84778
commit 9c1a41997b
4 changed files with 177 additions and 0 deletions

View File

@ -31,6 +31,7 @@ The plugin should then be available for use in your Python environment as
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
- Siamese GLVQ
- Neural Gas (NG)
- Growing Neural Gas (GNG)
## Work in Progress

View File

@ -14,6 +14,8 @@ Unsupervised Methods
.. autoclass:: prototorch.models.unsupervised.NeuralGas
:members:
.. autoclass:: prototorch.models.unsupervised.GrowingNeuralGas
:members:
Classical Learning Vector Quantization
-----------------------------------------

45
examples/gng_iris.py Normal file
View File

@ -0,0 +1,45 @@
import argparse
import prototorch as pt
import pytorch_lightning as pl
from prototorch.components.initializers import SelectionInitializer
from prototorch.datasets import Iris
from prototorch.models.unsupervised import GrowingNeuralGas
from torch.utils.data import DataLoader
if __name__ == "__main__":
# Command-line arguments
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Prepare the data
train_ds = Iris(dims=[0, 2])
train_loader = DataLoader(train_ds, batch_size=32)
# Hyperparameters
hparams = dict(num_prototypes=2,
lr=0.1,
prototype_initializer=SelectionInitializer(train_ds.data))
# Initialize the model
model = GrowingNeuralGas(hparams)
# Model summary
print(model)
# Callbacks
vis = pt.models.VisNG2D(data=train_loader)
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=100,
callbacks=[vis],
)
# Training loop
trainer.fit(model, train_loader)
# Model summary
print(model)

View File

@ -1,7 +1,9 @@
"""Unsupervised prototype learning algorithms."""
import logging
import warnings
import pytorch_lightning as pl
import torch
import torchmetrics
from prototorch.components import Components, LabeledComponents
@ -10,6 +12,7 @@ from prototorch.components.initializers import ZerosInitializer, parse_data_arg
from prototorch.functions.competitions import knnc
from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import NeuralGasEnergy
from pytorch_lightning.callbacks import Callback
from .abstract import AbstractPrototypeModel
@ -19,6 +22,61 @@ class EuclideanDistance(torch.nn.Module):
return euclidean_distance(x, y)
class GNGCallback(Callback):
"""GNG Callback.
Applies growing algorithm based on accumulated error and topology.
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
"""
def __init__(self, reduction=0.1, freq=10):
self.reduction = reduction
self.freq = freq
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
if (trainer.current_epoch + 1) % self.freq == 0:
# Get information
errors = pl_module.errors
topology: ConnectionTopology = pl_module.topology_layer
components: pt.components.Components = pl_module.proto_layer.components
# Insertion point
worst = torch.argmax(errors)
neighbours = topology.get_neighbours(worst)[0]
if len(neighbours) == 0:
logging.log("No Neighbour pair found")
return
neighbours_errors = errors[neighbours]
worst_neighbour = neighbours[torch.argmax(neighbours_errors)]
# New Prototype
new_component = 0.5 * (components[worst] +
components[worst_neighbour])
new_components = torch.vstack([components, new_component])
# Add component
pl_module.proto_layer.register_parameter(
"_components", torch.nn.parameter.Parameter(new_components))
# Adjust Topology
topology.add_prototype()
topology.add_connection(worst, -1)
topology.add_connection(worst_neighbour, -1)
topology.remove_connection(worst, worst_neighbour)
# New errors
worst_error = errors[worst].unsqueeze(0)
pl_module.errors = torch.cat([pl_module.errors, worst_error])
pl_module.errors[worst] = errors[worst] * self.reduction
pl_module.errors[
worst_neighbour] = errors[worst_neighbour] * self.reduction
trainer.accelerator_backend.setup_optimizers(trainer)
class ConnectionTopology(torch.nn.Module):
def __init__(self, agelimit, num_prototypes):
super().__init__()
@ -33,10 +91,44 @@ class ConnectionTopology(torch.nn.Module):
for element in order:
i0, i1 = element[0], element[1]
self.cmat[i0][i1] = 1
self.cmat[i1][i0] = 1
self.age[i0][i1] = 0
self.age[i1][i0] = 0
self.age[i0][self.cmat[i0] == 1] += 1
self.age[i1][self.cmat[i1] == 1] += 1
self.cmat[i0][self.age[i0] > self.agelimit] = 0
self.cmat[i1][self.age[i1] > self.agelimit] = 0
def get_neighbours(self, position):
return torch.where(self.cmat[position])
def add_prototype(self):
new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
new_cmat[:-1, :-1] = self.cmat
self.cmat = new_cmat
new_age = torch.zeros([dim + 1 for dim in self.age.shape])
new_age[:-1, :-1] = self.age
self.age = new_age
def add_connection(self, a, b):
self.cmat[a][b] = 1
self.cmat[b][a] = 1
self.age[a][b] = 0
self.age[b][a] = 0
def remove_connection(self, a, b):
self.cmat[a][b] = 0
self.cmat[b][a] = 0
self.age[a][b] = 0
self.age[b][a] = 0
def extra_repr(self):
return f"agelimit: {self.agelimit}"
@ -126,3 +218,40 @@ class NeuralGas(AbstractPrototypeModel):
self.topology_layer(d)
return cost
class GrowingNeuralGas(NeuralGas):
def __init__(self, hparams, **kwargs):
super().__init__(hparams, **kwargs)
# defaults
self.hparams.setdefault("step_reduction", 0.5)
self.hparams.setdefault("insert_reduction", 0.1)
self.hparams.setdefault("insert_freq", 10)
self.register_buffer("errors",
torch.zeros(self.hparams.num_prototypes))
def training_step(self, train_batch, _batch_idx):
x = train_batch[0]
protos = self.proto_layer()
d = self.distance_layer(x, protos)
cost, order = self.energy_layer(d)
winner = order[:, 0]
mask = torch.zeros_like(d)
mask[torch.arange(len(mask)), winner] = 1.0
winner_distances = d * mask
self.errors += torch.sum(winner_distances * winner_distances, dim=0)
self.errors *= self.hparams.step_reduction
self.topology_layer(d)
return cost
def configure_callbacks(self):
return [
GNGCallback(reduction=self.hparams.insert_reduction,
freq=self.hparams.insert_freq)
]