Update SOM

This commit is contained in:
Jensun Ravichandran 2021-06-09 18:21:12 +02:00
parent 022d791ea5
commit 57f8bec270
3 changed files with 17 additions and 6 deletions

View File

@ -81,7 +81,7 @@ if __name__ == "__main__":
hparams = dict( hparams = dict(
shape=(18, 32), shape=(18, 32),
alpha=1.0, alpha=1.0,
sigma=3, sigma=8,
lr=0.1, lr=0.1,
) )

View File

@ -10,7 +10,12 @@ from prototorch.functions.distances import euclidean_distance
from prototorch.modules import WTAC, LambdaLayer from prototorch.modules import WTAC, LambdaLayer
class ProtoTorchMixin:
    """Common base class for ProtoTorch mixins.

    It defines no attributes or methods of its own; ``NonGradientMixin`` and
    ``ImagePrototypesMixin`` derive from it.
    """
class ProtoTorchBolt(pl.LightningModule): class ProtoTorchBolt(pl.LightningModule):
"""All ProtoTorch models are ProtoTorch Bolts."""
def __repr__(self): def __repr__(self):
surep = super().__repr__() surep = super().__repr__()
indented = "".join([f"\t{line}\n" for line in surep.splitlines()]) indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
@ -160,7 +165,7 @@ class SupervisedPrototypeModel(PrototypeModel):
logger=True) logger=True)
class NonGradientMixin(): class NonGradientMixin(ProtoTorchMixin):
"""Mixin for custom non-gradient optimization.""" """Mixin for custom non-gradient optimization."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -170,7 +175,7 @@ class NonGradientMixin():
raise NotImplementedError raise NotImplementedError
class ImagePrototypesMixin(ProtoTorchBolt): class ImagePrototypesMixin(ProtoTorchMixin):
"""Mixin for models with image prototypes.""" """Mixin for models with image prototypes."""
@final @final
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):

View File

@ -1,9 +1,9 @@
"""Unsupervised prototype learning algorithms.""" """Unsupervised prototype learning algorithms."""
import numpy as np
import torch import torch
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import squared_euclidean_distance from prototorch.functions.distances import squared_euclidean_distance
from prototorch.functions.helper import get_flat
from prototorch.modules import LambdaLayer from prototorch.modules import LambdaLayer
from prototorch.modules.losses import NeuralGasEnergy from prototorch.modules.losses import NeuralGasEnergy
@ -36,6 +36,8 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
x, y = torch.arange(h), torch.arange(w) x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y), dim=-1) grid = torch.stack(torch.meshgrid(x, y), dim=-1)
self.register_buffer("_grid", grid) self.register_buffer("_grid", grid)
self._sigma = self.hparams.sigma
self._lr = self.hparams.lr
def predict_from_distances(self, distances): def predict_from_distances(self, distances):
grid = self._grid.view(-1, 2) grid = self._grid.view(-1, 2)
@ -50,14 +52,18 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
wp = self.predict_from_distances(d) wp = self.predict_from_distances(d)
grid = self._grid.view(-1, 2) grid = self._grid.view(-1, 2)
gd = squared_euclidean_distance(wp, grid) gd = squared_euclidean_distance(wp, grid)
nh = torch.exp(-gd / self.hparams.sigma**2) nh = torch.exp(-gd / self._sigma**2)
protos = self.proto_layer.components protos = self.proto_layer.components
diff = x.unsqueeze(dim=1) - protos diff = x.unsqueeze(dim=1) - protos
delta = self.hparams.lr * self.hparams.alpha * nh.unsqueeze(-1) * diff delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
updated_protos = protos + delta.sum(dim=0) updated_protos = protos + delta.sum(dim=0)
self.proto_layer.load_state_dict({"_components": updated_protos}, self.proto_layer.load_state_dict({"_components": updated_protos},
strict=False) strict=False)
def training_epoch_end(self, training_step_outputs):
    """Decay ``_sigma`` exponentially with the current epoch index.

    The new value is ``hparams.sigma * exp(-current_epoch / max_epochs)``,
    where ``max_epochs`` is read from the attached trainer.

    Args:
        training_step_outputs: Outputs gathered from the training steps of
            the finished epoch; not used here.
    """
    max_epochs = self.trainer.max_epochs
    # The schedule divides by the epoch budget. When the trainer has none
    # (``None``) or a non-positive one, keep the current ``_sigma`` instead
    # of failing with TypeError / ZeroDivisionError.
    if max_epochs is None or max_epochs <= 0:
        return
    decay = np.exp(-self.current_epoch / max_epochs)
    self._sigma = self.hparams.sigma * decay
def extra_repr(self):
    """Return a short description containing the shape of ``_grid``."""
    grid_shape = tuple(self._grid.shape)
    return f"(grid): (shape: {grid_shape})"