[FEATURE] Add Growing Neural Gas
This commit is contained in:
parent
1636c84778
commit
9c1a41997b
@ -31,6 +31,7 @@ The plugin should then be available for use in your Python environment as
|
||||
- Learning Vector Quantization Multi-Layer Network (LVQMLN)
|
||||
- Siamese GLVQ
|
||||
- Neural Gas (NG)
|
||||
- Growing Neural Gas (GNG)
|
||||
|
||||
## Work in Progress
|
||||
|
||||
|
@ -14,6 +14,8 @@ Unsupervised Methods
|
||||
.. autoclass:: prototorch.models.unsupervised.NeuralGas
|
||||
:members:
|
||||
|
||||
.. autoclass:: prototorch.models.unsupervised.GrowingNeuralGas
|
||||
:members:
|
||||
|
||||
Classical Learning Vector Quantization
|
||||
-----------------------------------------
|
||||
|
45
examples/gng_iris.py
Normal file
45
examples/gng_iris.py
Normal file
@ -0,0 +1,45 @@
|
||||
import argparse
|
||||
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
from prototorch.components.initializers import SelectionInitializer
|
||||
from prototorch.datasets import Iris
|
||||
from prototorch.models.unsupervised import GrowingNeuralGas
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Prepare the data
|
||||
train_ds = Iris(dims=[0, 2])
|
||||
train_loader = DataLoader(train_ds, batch_size=32)
|
||||
|
||||
# Hyperparameters
|
||||
hparams = dict(num_prototypes=2,
|
||||
lr=0.1,
|
||||
prototype_initializer=SelectionInitializer(train_ds.data))
|
||||
|
||||
# Initialize the model
|
||||
model = GrowingNeuralGas(hparams)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisNG2D(data=train_loader)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
max_epochs=100,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
||||
# Model summary
|
||||
print(model)
|
@ -1,7 +1,9 @@
|
||||
"""Unsupervised prototype learning algorithms."""
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchmetrics
|
||||
from prototorch.components import Components, LabeledComponents
|
||||
@ -10,6 +12,7 @@ from prototorch.components.initializers import ZerosInitializer, parse_data_arg
|
||||
from prototorch.functions.competitions import knnc
|
||||
from prototorch.functions.distances import euclidean_distance
|
||||
from prototorch.modules.losses import NeuralGasEnergy
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
|
||||
from .abstract import AbstractPrototypeModel
|
||||
|
||||
@ -19,6 +22,61 @@ class EuclideanDistance(torch.nn.Module):
|
||||
return euclidean_distance(x, y)
|
||||
|
||||
|
||||
class GNGCallback(Callback):
|
||||
"""GNG Callback.
|
||||
|
||||
Applies growing algorithm based on accumulated error and topology.
|
||||
|
||||
Based on "A Growing Neural Gas Network Learns Topologies" by Bernd Fritzke.
|
||||
"""
|
||||
def __init__(self, reduction=0.1, freq=10):
|
||||
self.reduction = reduction
|
||||
self.freq = freq
|
||||
|
||||
def on_epoch_end(self, trainer: pl.Trainer, pl_module):
|
||||
if (trainer.current_epoch + 1) % self.freq == 0:
|
||||
# Get information
|
||||
errors = pl_module.errors
|
||||
topology: ConnectionTopology = pl_module.topology_layer
|
||||
components: pt.components.Components = pl_module.proto_layer.components
|
||||
|
||||
# Insertion point
|
||||
worst = torch.argmax(errors)
|
||||
|
||||
neighbours = topology.get_neighbours(worst)[0]
|
||||
|
||||
if len(neighbours) == 0:
|
||||
logging.log("No Neighbour pair found")
|
||||
return
|
||||
|
||||
neighbours_errors = errors[neighbours]
|
||||
worst_neighbour = neighbours[torch.argmax(neighbours_errors)]
|
||||
|
||||
# New Prototype
|
||||
new_component = 0.5 * (components[worst] +
|
||||
components[worst_neighbour])
|
||||
new_components = torch.vstack([components, new_component])
|
||||
|
||||
# Add component
|
||||
pl_module.proto_layer.register_parameter(
|
||||
"_components", torch.nn.parameter.Parameter(new_components))
|
||||
|
||||
# Adjust Topology
|
||||
topology.add_prototype()
|
||||
topology.add_connection(worst, -1)
|
||||
topology.add_connection(worst_neighbour, -1)
|
||||
topology.remove_connection(worst, worst_neighbour)
|
||||
|
||||
# New errors
|
||||
worst_error = errors[worst].unsqueeze(0)
|
||||
pl_module.errors = torch.cat([pl_module.errors, worst_error])
|
||||
pl_module.errors[worst] = errors[worst] * self.reduction
|
||||
pl_module.errors[
|
||||
worst_neighbour] = errors[worst_neighbour] * self.reduction
|
||||
|
||||
trainer.accelerator_backend.setup_optimizers(trainer)
|
||||
|
||||
|
||||
class ConnectionTopology(torch.nn.Module):
|
||||
def __init__(self, agelimit, num_prototypes):
|
||||
super().__init__()
|
||||
@ -33,10 +91,44 @@ class ConnectionTopology(torch.nn.Module):
|
||||
|
||||
for element in order:
|
||||
i0, i1 = element[0], element[1]
|
||||
|
||||
self.cmat[i0][i1] = 1
|
||||
self.cmat[i1][i0] = 1
|
||||
|
||||
self.age[i0][i1] = 0
|
||||
self.age[i1][i0] = 0
|
||||
|
||||
self.age[i0][self.cmat[i0] == 1] += 1
|
||||
self.age[i1][self.cmat[i1] == 1] += 1
|
||||
|
||||
self.cmat[i0][self.age[i0] > self.agelimit] = 0
|
||||
self.cmat[i1][self.age[i1] > self.agelimit] = 0
|
||||
|
||||
def get_neighbours(self, position):
|
||||
return torch.where(self.cmat[position])
|
||||
|
||||
def add_prototype(self):
|
||||
new_cmat = torch.zeros([dim + 1 for dim in self.cmat.shape])
|
||||
new_cmat[:-1, :-1] = self.cmat
|
||||
self.cmat = new_cmat
|
||||
|
||||
new_age = torch.zeros([dim + 1 for dim in self.age.shape])
|
||||
new_age[:-1, :-1] = self.age
|
||||
self.age = new_age
|
||||
|
||||
def add_connection(self, a, b):
|
||||
self.cmat[a][b] = 1
|
||||
self.cmat[b][a] = 1
|
||||
|
||||
self.age[a][b] = 0
|
||||
self.age[b][a] = 0
|
||||
|
||||
def remove_connection(self, a, b):
|
||||
self.cmat[a][b] = 0
|
||||
self.cmat[b][a] = 0
|
||||
|
||||
self.age[a][b] = 0
|
||||
self.age[b][a] = 0
|
||||
|
||||
def extra_repr(self):
|
||||
return f"agelimit: {self.agelimit}"
|
||||
@ -126,3 +218,40 @@ class NeuralGas(AbstractPrototypeModel):
|
||||
|
||||
self.topology_layer(d)
|
||||
return cost
|
||||
|
||||
|
||||
class GrowingNeuralGas(NeuralGas):
|
||||
def __init__(self, hparams, **kwargs):
|
||||
super().__init__(hparams, **kwargs)
|
||||
|
||||
# defaults
|
||||
self.hparams.setdefault("step_reduction", 0.5)
|
||||
self.hparams.setdefault("insert_reduction", 0.1)
|
||||
self.hparams.setdefault("insert_freq", 10)
|
||||
|
||||
self.register_buffer("errors",
|
||||
torch.zeros(self.hparams.num_prototypes))
|
||||
|
||||
def training_step(self, train_batch, _batch_idx):
|
||||
x = train_batch[0]
|
||||
protos = self.proto_layer()
|
||||
d = self.distance_layer(x, protos)
|
||||
cost, order = self.energy_layer(d)
|
||||
|
||||
winner = order[:, 0]
|
||||
|
||||
mask = torch.zeros_like(d)
|
||||
mask[torch.arange(len(mask)), winner] = 1.0
|
||||
winner_distances = d * mask
|
||||
|
||||
self.errors += torch.sum(winner_distances * winner_distances, dim=0)
|
||||
self.errors *= self.hparams.step_reduction
|
||||
|
||||
self.topology_layer(d)
|
||||
return cost
|
||||
|
||||
def configure_callbacks(self):
|
||||
return [
|
||||
GNGCallback(reduction=self.hparams.insert_reduction,
|
||||
freq=self.hparams.insert_freq)
|
||||
]
|
||||
|
Loading…
Reference in New Issue
Block a user