Update SOM

parent 022d791ea5
commit 57f8bec270
@@ -81,7 +81,7 @@ if __name__ == "__main__":
     hparams = dict(
         shape=(18, 32),
         alpha=1.0,
-        sigma=3,
+        sigma=8,
         lr=0.1,
     )
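The only change in the example script is a wider initial neighborhood: sigma goes from 3 to 8 on an 18x32 map. A minimal standalone sketch (not part of the commit) of how sigma scales the Gaussian neighborhood that the SOM update below uses:

import torch

gd = torch.tensor([0., 1., 4., 9., 16., 25.])  # squared grid distances to the winner
for sigma in (3, 8):
    nh = torch.exp(-gd / sigma**2)  # neighborhood strength per map unit
    print(f"sigma={sigma}:", [round(v, 2) for v in nh.tolist()])

With sigma=3 the pull fades within a couple of grid cells; with sigma=8 a broad patch of the map follows each winner, which helps the map unfold early in training.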
@@ -10,7 +10,12 @@ from prototorch.functions.distances import euclidean_distance
 from prototorch.modules import WTAC, LambdaLayer
 
 
+class ProtoTorchMixin(object):
+    pass
+
+
 class ProtoTorchBolt(pl.LightningModule):
+    """All ProtoTorch models are ProtoTorch Bolts."""
     def __repr__(self):
         surep = super().__repr__()
         indented = "".join([f"\t{line}\n" for line in surep.splitlines()])
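ProtoTorchMixin is introduced as a deliberately empty base class, giving all mixins a common ancestor that stays outside the pl.LightningModule hierarchy. A stand-in sketch (the LoggingMixin and Host classes are hypothetical, for illustration only) of how a mixin built on it still cooperates with super():

class ProtoTorchMixin(object):
    pass

class LoggingMixin(ProtoTorchMixin):  # hypothetical, not in the commit
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)  # forwards along the host's MRO
        print("LoggingMixin ready")

class Host(LoggingMixin):  # any class can mix it in
    pass

Host()  # prints "LoggingMixin ready"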
@@ -160,7 +165,7 @@ class SupervisedPrototypeModel(PrototypeModel):
                  logger=True)
 
 
-class NonGradientMixin():
+class NonGradientMixin(ProtoTorchMixin):
     """Mixin for custom non-gradient optimization."""
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -170,7 +175,7 @@ class NonGradientMixin():
         raise NotImplementedError
 
 
-class ImagePrototypesMixin(ProtoTorchBolt):
+class ImagePrototypesMixin(ProtoTorchMixin):
     """Mixin for models with image prototypes."""
     @final
     def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
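Both mixins are re-parented from ProtoTorchBolt to the new ProtoTorchMixin, so a model that combines a mixin with a prototype model no longer inherits the Lightning base along two paths. A stand-in sketch (class bodies simplified, names taken from the diff) of the resulting linearization:

class ProtoTorchMixin(object):
    pass

class ProtoTorchBolt:  # stand-in for the pl.LightningModule subclass
    pass

class UnsupervisedPrototypeModel(ProtoTorchBolt):
    pass

class NonGradientMixin(ProtoTorchMixin):
    pass

class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
    pass

print([c.__name__ for c in KohonenSOM.__mro__])
# ['KohonenSOM', 'NonGradientMixin', 'ProtoTorchMixin',
#  'UnsupervisedPrototypeModel', 'ProtoTorchBolt', 'object']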
@@ -1,9 +1,9 @@
 """Unsupervised prototype learning algorithms."""
 
+import numpy as np
 import torch
 
 from prototorch.functions.competitions import wtac
 from prototorch.functions.distances import squared_euclidean_distance
-from prototorch.functions.helper import get_flat
 from prototorch.modules import LambdaLayer
 from prototorch.modules.losses import NeuralGasEnergy
@@ -36,6 +36,8 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
         x, y = torch.arange(h), torch.arange(w)
         grid = torch.stack(torch.meshgrid(x, y), dim=-1)
         self.register_buffer("_grid", grid)
+        self._sigma = self.hparams.sigma
+        self._lr = self.hparams.lr
 
     def predict_from_distances(self, distances):
         grid = self._grid.view(-1, 2)
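__init__ now caches hparams.sigma and hparams.lr in plain attributes: _sigma is about to become mutable state that decays per epoch (see the next hunk), while _grid stays a registered buffer so it follows the module across devices. For context, a small sketch (assumed 2x3 map, not from the commit) of what that buffer holds:

import torch

x, y = torch.arange(2), torch.arange(3)
grid = torch.stack(torch.meshgrid(x, y), dim=-1)  # shape (2, 3, 2)
print(grid.view(-1, 2).tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
# one (row, col) coordinate per map unit, in the flattened
# layout that predict_from_distances works with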
@@ -50,14 +52,18 @@ class KohonenSOM(NonGradientMixin, UnsupervisedPrototypeModel):
         wp = self.predict_from_distances(d)
         grid = self._grid.view(-1, 2)
         gd = squared_euclidean_distance(wp, grid)
-        nh = torch.exp(-gd / self.hparams.sigma**2)
+        nh = torch.exp(-gd / self._sigma**2)
         protos = self.proto_layer.components
         diff = x.unsqueeze(dim=1) - protos
-        delta = self.hparams.lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
+        delta = self._lr * self.hparams.alpha * nh.unsqueeze(-1) * diff
         updated_protos = protos + delta.sum(dim=0)
         self.proto_layer.load_state_dict({"_components": updated_protos},
                                          strict=False)
 
+    def training_epoch_end(self, training_step_outputs):
+        self._sigma = self.hparams.sigma * np.exp(
+            -self.current_epoch / self.trainer.max_epochs)
+
     def extra_repr(self):
         return f"(grid): (shape: {tuple(self._grid.shape)})"
 
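training_step now reads the decayed _sigma and cached _lr instead of the static hparams, and the new training_epoch_end hook shrinks the neighborhood exponentially over the run while hparams.sigma keeps the starting value. The resulting schedule, sketched with assumed values:

import numpy as np

sigma0, max_epochs = 8, 10  # assumed values for illustration
for epoch in range(max_epochs):
    print(f"epoch {epoch}: sigma = {sigma0 * np.exp(-epoch / max_epochs):.3f}")
# falls smoothly from 8.000 toward sigma0/e (about 2.943), so late epochs
# only fine-tune each winner's immediate neighborhood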