Update SOM
parent 022d791ea5
commit 57f8bec270
@@ -81,7 +81,7 @@ if __name__ == "__main__":
     hparams = dict(
         shape=(18, 32),
         alpha=1.0,
-        sigma=3,
+        sigma=8,
         lr=0.1,
     )
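Note on this hyperparameter change: the SOM training step further down weights each unit's update by the Gaussian neighborhood exp(-d^2 / sigma^2), where d is the winner-to-unit distance on the grid, so raising sigma from 3 to 8 starts training with a much wider neighborhood. A quick standalone illustration (the helper name neighborhood is ours, not from the repo):

import torch

def neighborhood(grid_dist_sq, sigma):
    # Gaussian neighborhood weight, as used in the SOM training step below
    return torch.exp(-grid_dist_sq / sigma**2)

d_sq = torch.tensor([0.0, 1.0, 4.0, 16.0])   # units 0, 1, 2, and 4 cells away
print(neighborhood(d_sq, sigma=3))  # tensor([1.0000, 0.8948, 0.6412, 0.1690])
print(neighborhood(d_sq, sigma=8))  # tensor([1.0000, 0.9845, 0.9394, 0.7788])

At sigma=8 a unit four grid cells from the winner still receives about 78% of the full update, versus about 17% at sigma=3.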
@@ -10,7 +10,12 @@ from prototorch.functions.distances import euclidean_distance
 from prototorch.modules import WTAC, LambdaLayer
 
 
+class ProtoTorchMixin(object):
+    pass
+
+
 class ProtoTorchBolt(pl.LightningModule):
     """All ProtoTorch models are ProtoTorch Bolts."""
     def __repr__(self):
         surep = super().__repr__()
         indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
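Why a plain mixin base: deriving the mixins from ProtoTorchMixin (a bare object subclass) instead of ProtoTorchBolt keeps pl.LightningModule out of the mixin branch of the MRO, so a model like KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel) picks up the Lightning machinery exactly once, through the model branch. The two hunks below rebase NonGradientMixin and ImagePrototypesMixin onto it. A minimal sketch of the pattern (Base, Mixin, Model, SOM are illustrative stand-ins, not repo classes):

class Base:          # stands in for pl.LightningModule
    def __init__(self):
        print("Base.__init__ runs once")

class Mixin:         # plain base, like the new ProtoTorchMixin
    pass

class Model(Base):   # stands in for UnsupervisedPrototypeModel
    pass

class SOM(Mixin, Model):  # mixin listed first so its overrides take precedence
    pass

SOM()                                     # -> Base.__init__ runs once
print([c.__name__ for c in SOM.__mro__])  # ['SOM', 'Mixin', 'Model', 'Base', 'object']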
@@ -160,7 +165,7 @@ class SupervisedPrototypeModel(PrototypeModel):
                  logger=True)
 
 
-class NonGradientMixin():
+class NonGradientMixin(ProtoTorchMixin):
     """Mixin for custom non-gradient optimization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -170,7 +175,7 @@ class NonGradientMixin():
         raise NotImplementedError
 
 
-class ImagePrototypesMixin(ProtoTorchBolt):
+class ImagePrototypesMixin(ProtoTorchMixin):
     """Mixin for models with image prototypes."""
     @final
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
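On the @final decorator kept in the context lines: assuming it is typing.final, it documents that subclasses must not override the hook; static checkers such as mypy enforce this, while Python itself does not at runtime. A hypothetical violation (class names are illustrative):

from typing import final

class MixinLike:
    @final
    def on_train_batch_end(self, *args):   # sealed hook
        ...

class MyModel(MixinLike):
    def on_train_batch_end(self, *args):   # mypy: cannot override final method
        ...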
@@ -1,9 +1,9 @@
 """Unsupervised prototype learning algorithms."""
 
+import numpy as np
 import torch
 
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import squared_euclidean_distance
 from prototorch.functions.helper import get_flat
 from prototorch.modules import LambdaLayer
 from prototorch.modules.losses import NeuralGasEnergy
@@ -36,6 +36,8 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
         x, y = torch.arange(h), torch.arange(w)
         grid = torch.stack(torch.meshgrid(x, y), dim=-1)
         self.register_buffer("_grid", grid)
+        self._sigma = self.hparams.sigma
+        self._lr = self.hparams.lr
 
     def predict_from_distances(self, distances):
         grid = self._grid.view(-1, 2)
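For orientation, the buffered _grid stores one (row, col) coordinate per SOM unit; with the example's shape=(18, 32) it is an 18x32x2 tensor that the methods below flatten to 576x2. Using register_buffer keeps it on the module's device and in its state dict without making it a trainable parameter. A quick standalone check:

import torch

h, w = 18, 32
x, y = torch.arange(h), torch.arange(w)
grid = torch.stack(torch.meshgrid(x, y), dim=-1)
print(grid.shape)                # torch.Size([18, 32, 2])
print(grid.view(-1, 2).shape)    # torch.Size([576, 2])
print(grid[0, 0], grid[17, 31])  # tensor([0, 0]) tensor([17, 31])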
@@ -50,14 +52,18 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
         wp = self.predict_from_distances(d)
         grid = self._grid.view(-1, 2)
         gd = squared_euclidean_distance(wp, grid)
-        nh = torch.exp(-gd / self.hparams.sigma**2)
+        nh = torch.exp(-gd / self._sigma**2)
         protos = self.proto_layer.components
         diff = x.unsqueeze(dim=1) - protos
-        delta = self.hparams.lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
+        delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
         updated_protos = protos + delta.sum(dim=0)
         self.proto_layer.load_state_dict({"_components": updated_protos},
                                          strict=False)
 
+    def training_epoch_end(self, training_step_outputs):
+        self._sigma = self.hparams.sigma * np.exp(
+            -self.current_epoch / self.trainer.max_epochs)
+
     def extra_repr(self):
         return f"(grid): (shape: {tuple(self._grid.shape)})"
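Putting the hunks together: each training step pulls every prototype toward each sample with a strength given by the Gaussian neighborhood of that sample's winning unit, and training_epoch_end now anneals the working radius _sigma exponentially over training while hparams.sigma retains the initial value (which is why the copies _sigma and _lr were introduced). A self-contained sketch of that update rule under these assumptions; distances are computed with torch.cdist here for brevity, and n_epochs, the batch, and the loop are illustrative:

import numpy as np
import torch

h, w, dim = 18, 32, 3                  # grid shape and data dimensionality
sigma0, lr, alpha = 8.0, 0.1, 1.0      # initial radius, learning rate, scale
n_epochs = 10
protos = torch.rand(h * w, dim)        # stands in for proto_layer.components
grid = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)),
                   dim=-1).view(-1, 2).float()

def som_step(x, protos, sigma):
    d = torch.cdist(x, protos)         # (batch, units) data-space distances
    wp = grid[d.argmin(dim=1)]         # grid position of each winner (BMU)
    gd = torch.cdist(wp, grid)**2      # squared grid distance winner -> units
    nh = torch.exp(-gd / sigma**2)     # Gaussian neighborhood weights
    diff = x.unsqueeze(1) - protos     # (batch, units, dim) pull directions
    delta = lr * alpha * nh.unsqueeze(-1) * diff
    return protos + delta.sum(dim=0)   # accumulate the update over the batch

for epoch in range(n_epochs):
    sigma = sigma0 * np.exp(-epoch / n_epochs)  # decay from training_epoch_end
    x = torch.rand(64, dim)
    protos = som_step(x, protos, sigma)

Note that in the diff the update is written back into the component buffer via load_state_dict(..., strict=False) rather than applied by an optimizer, which is exactly what the NonGradientMixin base signals.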