Update SOM example

This commit is contained in:
Jensun Ravichandran 2021-06-11 11:29:47 +02:00
parent 4ca846997a
commit ea33196a50

View File

@ -37,6 +37,7 @@ class Vis2DColorSOM(pl.Callback):
h, w = pl_module._grid.shape[:2]
protos = pl_module.prototypes.view(h, w, 3)
ax.imshow(protos)
ax.axis("off")
# Overlay color names
d = pl_module.compute_distances(self.data)
@ -49,7 +50,10 @@ class Vis2DColorSOM(pl.Callback):
va="center",
bbox=dict(facecolor="white", alpha=0.5, lw=0))
plt.pause(self.pause_time)
if trainer.current_epoch != trainer.max_epochs - 1:
plt.pause(self.pause_time)
else:
plt.show(block=True)
if __name__ == "__main__":
@ -65,12 +69,12 @@ if __name__ == "__main__":
hex_colors = [
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
"#545454", "#7f7f7f", "#a8a8a8"
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
]
cnames = [
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
"red", "cyan", "violet", "yellow", "white", "darkgrey", "mediumgrey",
"lightgrey"
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
"lightgrey", "olive", "purple", "orange"
]
colors = list(hex_to_rgb(hex_colors))
data = torch.Tensor(colors) / 255.0
@ -81,7 +85,7 @@ if __name__ == "__main__":
hparams = dict(
shape=(18, 32),
alpha=1.0,
sigma=8,
sigma=16,
lr=0.1,
)
@ -103,7 +107,7 @@ if __name__ == "__main__":
# Setup trainer
trainer = pl.Trainer.from_argparse_args(
args,
max_epochs=300,
max_epochs=500,
callbacks=[vis],
weights_summary="full",
)