prototorch_models/prototorch/models/unsupervised.py

156 lines
4.7 KiB
Python
Raw Normal View History

2021-05-21 13:42:45 +00:00
"""Unsupervised prototype learning algorithms."""
2021-06-09 16:21:12 +00:00
import numpy as np
2021-04-23 15:30:23 +00:00
import torch
2022-05-16 09:12:53 +00:00
from prototorch.core.competitions import wtac
from prototorch.core.distances import squared_euclidean_distance
from prototorch.core.losses import NeuralGasEnergy
2021-04-23 15:30:23 +00:00
2022-05-17 14:19:47 +00:00
from .abstract import UnsupervisedPrototypeModel
2021-06-04 20:20:32 +00:00
from .callbacks import GNGCallback
from .extras import ConnectionTopology
2022-05-17 14:19:47 +00:00
from .mixins import NonGradientMixin
2021-04-23 15:30:23 +00:00
2021-06-07 16:44:15 +00:00
class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
    """Kohonen Self-Organizing-Map.

    Prototypes are arranged on a fixed 2D grid of shape ``hparams["shape"]``
    ``== (h, w)``; any ``num_prototypes`` in ``hparams`` is overwritten with
    ``h * w``. Each training step moves every prototype towards the batch
    samples, weighted by ``exp(-grid_dist**2 / sigma**2)`` around each
    sample's winning grid node. The update is written directly into the
    prototype layer's state rather than returned as a loss.

    TODO Allow non-2D grids
    """

    _grid: torch.Tensor

    def __init__(self, hparams, **kwargs):
        h, w = hparams.get("shape")
        # Ignore `num_prototypes`: the grid shape determines it.
        # NOTE: this writes into the caller's ``hparams`` object.
        hparams["num_prototypes"] = h * w
        distance_fn = kwargs.pop("distance_fn", squared_euclidean_distance)
        super().__init__(hparams, distance_fn=distance_fn, **kwargs)

        # Hyperparameters
        self.save_hyperparameters(hparams)

        # Default hparams
        self.hparams.setdefault("alpha", 0.3)
        self.hparams.setdefault("sigma", max(h, w) / 2.0)

        # Integer (row, col) coordinates of every grid node, shape (h, w, 2).
        x, y = torch.arange(h), torch.arange(w)
        grid = torch.stack(torch.meshgrid(x, y, indexing="ij"), dim=-1)
        self.register_buffer("_grid", grid)

        # Current neighbourhood width (decayed each epoch) and learning rate.
        self._sigma = self.hparams.sigma
        self._lr = self.hparams.lr

    def predict_from_distances(self, distances):
        """Map each sample to the grid coordinates of its winning prototype.

        ``wtac`` (winner-takes-all competition) selects, per row of
        ``distances``, the entry of the flattened ``(h * w, 2)`` coordinate
        table that belongs to the closest prototype.
        """
        grid = self._grid.view(-1, 2)
        wp = wtac(distances, grid)
        return wp

    def training_step(self, train_batch, batch_idx):
        """Apply one SOM update from a batch; returns nothing (no loss)."""
        # TODO Check if the batch has labels
        x = train_batch[0]
        d = self.compute_distances(x)
        # Grid position of each sample's winner, then its squared grid
        # distance to every node -> neighbourhood weights per (sample, node).
        wp = self.predict_from_distances(d)
        grid = self._grid.view(-1, 2)
        gd = squared_euclidean_distance(wp, grid)
        nh = torch.exp(-gd / self._sigma**2)
        protos = self.proto_layer()
        # (batch, num_prototypes, dim): pull of each sample on each prototype.
        diff = x.unsqueeze(dim=1) - protos
        delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
        # Batch contributions are summed, then written back in place.
        updated_protos = protos + delta.sum(dim=0)
        self.proto_layer.load_state_dict(
            {"_components": updated_protos},
            strict=False,
        )

    def training_epoch_end(self, training_step_outputs):
        """Exponentially shrink the neighbourhood width over training.

        ``sigma`` follows ``sigma0 * exp(-epoch / max_epochs)``. Depending on
        the Lightning version, ``trainer.max_epochs`` is ``None`` or ``-1``
        when training is bounded by steps or unbounded. The schedule is
        undefined then, so ``sigma`` is left unchanged; without this check it
        would either raise (``None``) or grow exponentially (negative
        denominator).
        """
        max_epochs = self.trainer.max_epochs
        if max_epochs is None or max_epochs <= 0:
            return
        self._sigma = self.hparams.sigma * np.exp(
            -self.current_epoch / max_epochs)

    def extra_repr(self):
        return f"(grid): (shape: {tuple(self._grid.shape)})"
class HeskesSOM(UnsupervisedPrototypeModel):
    """Placeholder for a Heskes Self-Organizing-Map variant.

    Construction is delegated unchanged to the base model; training is not
    implemented yet and raises ``NotImplementedError``.
    """

    def __init__(self, hparams, **kwargs):
        super(HeskesSOM, self).__init__(hparams, **kwargs)

    def training_step(self, train_batch, batch_idx):
        # TODO Implement me!
        raise NotImplementedError
2021-06-04 20:20:32 +00:00
class NeuralGas(UnsupervisedPrototypeModel):
    """Neural Gas prototype model.

    The training loss is ``NeuralGasEnergy`` (parameter ``lm``) evaluated on
    the sample-to-prototype distances. The same distances are also fed to a
    ``ConnectionTopology`` layer configured with ``age_limit`` and
    ``num_prototypes``.
    """

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Persist hyperparameters, then fill in optional ones (in this order).
        self.save_hyperparameters(hparams)
        for name, fallback in (("age_limit", 10), ("lm", 1)):
            self.hparams.setdefault(name, fallback)

        self.energy_layer = NeuralGasEnergy(lm=self.hparams["lm"])
        self.topology_layer = ConnectionTopology(
            agelimit=self.hparams["age_limit"],
            num_prototypes=self.hparams["num_prototypes"],
        )

    def training_step(self, train_batch, batch_idx):
        """Compute and log the energy loss; also update the topology."""
        # TODO Check if the batch has labels
        distances = self.compute_distances(train_batch[0])
        energy, _ = self.energy_layer(distances)
        self.topology_layer(distances)
        self.log("loss", energy)
        return energy
2021-06-01 15:19:43 +00:00
class GrowingNeuralGas(NeuralGas):
    """Growing Neural Gas.

    Extends ``NeuralGas`` with an ``errors`` buffer that holds one
    accumulated error value per prototype. It also registers a
    ``GNGCallback`` configured with ``insert_reduction`` and ``insert_freq``,
    which presumably reads ``errors`` to choose where to grow the network.
    """

    errors: torch.Tensor

    def __init__(self, hparams, **kwargs):
        super().__init__(hparams, **kwargs)

        # Defaults
        self.hparams.setdefault("step_reduction", 0.5)
        self.hparams.setdefault("insert_reduction", 0.1)
        self.hparams.setdefault("insert_freq", 10)

        # One accumulated error value per prototype.
        errors = torch.zeros(
            self.hparams["num_prototypes"],
            device=self.device,
        )
        self.register_buffer("errors", errors)

    def training_step(self, train_batch, _batch_idx):
        """Compute the energy loss and update per-prototype errors.

        Each prototype accumulates ``dp * dp`` over the samples it won in the
        batch. Afterwards all errors are scaled by ``step_reduction``.
        """
        # TODO Check if the batch has labels
        x = train_batch[0]
        d = self.compute_distances(x)
        loss, order = self.energy_layer(d)
        winner = order[:, 0]
        # One-hot (batch, num_prototypes) mask selecting each sample's winner.
        mask = torch.zeros_like(d)
        mask[torch.arange(len(mask)), winner] = 1.0
        # Detach: the error statistics are bookkeeping, not part of the loss.
        # Without it the in-place updates below would give ``errors`` an
        # autograd history that grows with every step and is never freed.
        dp = d.detach() * mask

        # Reduce over the batch only, so each prototype receives the error of
        # the samples it won. A full scalar sum would give every prototype
        # the same error.
        self.errors += torch.sum(dp * dp, dim=0)
        self.errors *= self.hparams["step_reduction"]

        self.topology_layer(d)
        self.log("loss", loss)
        return loss

    def configure_callbacks(self):
        return [
            GNGCallback(
                reduction=self.hparams["insert_reduction"],
                freq=self.hparams["insert_freq"],
            )
        ]