diff --git a/README.md b/README.md
index 08f5938..15e71c3 100644
--- a/README.md
+++ b/README.md
@@ -44,6 +44,12 @@ To assist in the development process, you may also find it useful to install
 
 ## Available models
 - [X] GLVQ
+- [X] Neural Gas
+
+## Work in Progress
+- [ ] CBC
+
+## Planned models
 - [ ] GMLVQ
 - [ ] Local-Matrix GMLVQ
 - [ ] Limited-Rank GMLVQ
@@ -51,4 +57,3 @@ To assist in the development process, you may also find it useful to install
 - [ ] RSLVQ
 - [ ] PLVQ
 - [ ] LVQMLN
-- [ ] CBC
diff --git a/examples/ng_iris.py b/examples/ng_iris.py
new file mode 100644
index 0000000..7c1bd18
--- /dev/null
+++ b/examples/ng_iris.py
@@ -0,0 +1,107 @@
+"""Neural Gas example using the Iris dataset."""
+
+import numpy as np
+import pytorch_lightning as pl
+from matplotlib import pyplot as plt
+from sklearn.datasets import load_iris
+from sklearn.preprocessing import StandardScaler
+from torch.utils.data import DataLoader
+
+from prototorch.datasets.abstract import NumpyDataset
+from prototorch.models.neural_gas import NeuralGas
+
+
+class VisualizationCallback(pl.Callback):
+    """Plot the data, the prototypes and their connections every epoch."""
+    def __init__(self,
+                 x_train,
+                 y_train,
+                 title="Neural Gas Visualization",
+                 cmap="viridis"):
+        super().__init__()
+        self.x_train = x_train
+        self.y_train = y_train
+        self.title = title
+        self.fig = plt.figure(self.title)
+        self.cmap = cmap
+
+    def on_epoch_end(self, trainer, pl_module: NeuralGas):
+        """Redraw the figure from the module's prototypes and topology."""
+        protos = pl_module.proto_layer.prototypes.detach().cpu().numpy()
+        cmat = pl_module.topology_layer.cmat.cpu().numpy()
+
+        # Visualize the data and the prototypes
+        ax = self.fig.gca()
+        ax.cla()
+        ax.set_title(self.title)
+        ax.set_xlabel("Data dimension 1")
+        ax.set_ylabel("Data dimension 2")
+        ax.scatter(self.x_train[:, 0],
+                   self.x_train[:, 1],
+                   c=self.y_train,
+                   cmap=self.cmap,
+                   edgecolor="k")
+        ax.scatter(
+            protos[:, 0],
+            protos[:, 1],
+            c="k",
+            edgecolor="k",
+            marker="D",
+            s=50,
+        )
+
+        # Draw one line per non-zero entry of the connection matrix
+        for i, j in np.argwhere(cmat):
+            ax.plot(
+                [protos[i, 0], protos[j, 0]],
+                [protos[i, 1], protos[j, 1]],
+                "k-",
+            )
+
+        plt.pause(0.01)
+
+
+if __name__ == "__main__":
+    # Dataset
+    x_train, y_train = load_iris(return_X_y=True)
+    x_train = x_train[:, [0, 2]]
+    scaler = StandardScaler()
+    scaler.fit(x_train)
+    x_train = scaler.transform(x_train)
+
+    # The model is initialized with a single dummy class; the true labels
+    # are kept in the dataset and used to color the scatter plot.
+    y_single_class = np.zeros_like(y_train)
+    train_ds = NumpyDataset(x_train, y_train)
+
+    # Dataloaders
+    train_loader = DataLoader(train_ds, num_workers=0, batch_size=150)
+
+    # Hyperparameters
+    hparams = dict(
+        input_dim=x_train.shape[1],
+        nclasses=1,
+        prototypes_per_class=30,
+        prototype_initializer="rand",
+        lr=0.01,
+    )
+
+    # Initialize the model
+    model = NeuralGas(hparams, data=[x_train, y_single_class])
+
+    # Model summary
+    print(model)
+
+    # Callbacks
+    vis = VisualizationCallback(x_train, y_train)
+
+    # Setup trainer
+    trainer = pl.Trainer(
+        max_epochs=100,
+        callbacks=[
+            vis,
+        ],
+    )
+
+    # Training loop
+    trainer.fit(model, train_loader)
diff --git a/prototorch/models/neural_gas.py b/prototorch/models/neural_gas.py
new file mode 100644
index 0000000..98bb5e7
--- /dev/null
+++ b/prototorch/models/neural_gas.py
@@ -0,0 +1,87 @@
+import pytorch_lightning as pl
+import torch
+
+from prototorch.functions.distances import euclidean_distance
+from prototorch.modules import Prototypes1D
+from prototorch.modules.losses import NeuralGasEnergy
+
+
+class EuclideanDistance(torch.nn.Module):
+    """Module wrapper around the functional `euclidean_distance`."""
+    def forward(self, x, y):
+        return euclidean_distance(x, y)
+
+
+class ConnectionTopology(torch.nn.Module):
+    """Track aged, directed connections between prototype pairs.
+
+    `cmat[i, j]` is 1 while prototype `i` is connected to prototype `j`
+    and `age[i, j]` holds the age of that connection. Connections whose
+    age exceeds `agelimit` are removed.
+    """
+    def __init__(self, agelimit, num_prototypes):
+        super().__init__()
+        self.agelimit = agelimit
+        self.num_prototypes = num_prototypes
+
+        # Buffers (not plain attributes) so that the topology state follows
+        # the module across devices and is stored in checkpoints.
+        self.register_buffer(
+            "cmat", torch.zeros((self.num_prototypes, self.num_prototypes)))
+        self.register_buffer("age", torch.zeros_like(self.cmat))
+
+    def forward(self, d):
+        """Update the connections in place from the distance matrix `d`."""
+        order = torch.argsort(d, dim=1)
+
+        for element in order:
+            # Indices of the two smallest distances in this row
+            i0, i1 = element[0], element[1]
+            self.cmat[i0][i1] = 1
+            self.age[i0][i1] = 0
+            self.age[i0][self.cmat[i0] == 1] += 1
+            self.cmat[i0][self.age[i0] > self.agelimit] = 0
+
+    def extra_repr(self):
+        return f"agelimit: {self.agelimit}"
+
+
+class NeuralGas(pl.LightningModule):
+    """Prototype model trained by minimizing the `NeuralGasEnergy` loss."""
+    def __init__(self, hparams, **kwargs):
+        super().__init__()
+
+        self.save_hyperparameters(hparams)
+
+        # Default Values
+        self.hparams.setdefault("agelimit", 10)
+        self.hparams.setdefault("lm", 1)
+        self.hparams.setdefault("prototype_initializer", "zeros")
+
+        self.proto_layer = Prototypes1D(
+            input_dim=self.hparams.input_dim,
+            nclasses=self.hparams.nclasses,
+            prototypes_per_class=self.hparams.prototypes_per_class,
+            prototype_initializer=self.hparams.prototype_initializer,
+            **kwargs,
+        )
+
+        self.distance_layer = EuclideanDistance()
+        self.energy_layer = NeuralGasEnergy(lm=self.hparams.lm)
+        self.topology_layer = ConnectionTopology(
+            agelimit=self.hparams.agelimit,
+            num_prototypes=len(self.proto_layer.prototypes),
+        )
+
+    def training_step(self, train_batch, batch_idx):
+        x, _ = train_batch
+        protos, _ = self.proto_layer()
+        d = self.distance_layer(x, protos)
+        cost, _ = self.energy_layer(d)
+
+        # The topology update is bookkeeping only; keep it out of autograd.
+        self.topology_layer(d.detach())
+        return cost
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)