Update SOM example

This commit is contained in:
Jensun Ravichandran 2021-06-11 11:29:47 +02:00
parent 4ca846997a
commit ea33196a50

View File

@ -37,6 +37,7 @@ class Vis2DColorSOM(pl.Callback):
h, w = pl_module._grid.shape[:2] h, w = pl_module._grid.shape[:2]
protos = pl_module.prototypes.view(h, w, 3) protos = pl_module.prototypes.view(h, w, 3)
ax.imshow(protos) ax.imshow(protos)
ax.axis("off")
# Overlay color names # Overlay color names
d = pl_module.compute_distances(self.data) d = pl_module.compute_distances(self.data)
@ -49,7 +50,10 @@ class Vis2DColorSOM(pl.Callback):
va="center", va="center",
bbox=dict(facecolor="white", alpha=0.5, lw=0)) bbox=dict(facecolor="white", alpha=0.5, lw=0))
plt.pause(self.pause_time) if trainer.current_epoch != trainer.max_epochs - 1:
plt.pause(self.pause_time)
else:
plt.show(block=True)
if __name__ == "__main__": if __name__ == "__main__":
@ -65,12 +69,12 @@ if __name__ == "__main__":
hex_colors = [ hex_colors = [
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff", "#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff", "#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
"#545454", "#7f7f7f", "#a8a8a8" "#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
] ]
cnames = [ cnames = [
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green", "black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
"red", "cyan", "violet", "yellow", "white", "darkgrey", "mediumgrey", "red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
"lightgrey" "lightgrey", "olive", "purple", "orange"
] ]
colors = list(hex_to_rgb(hex_colors)) colors = list(hex_to_rgb(hex_colors))
data = torch.Tensor(colors) / 255.0 data = torch.Tensor(colors) / 255.0
@ -81,7 +85,7 @@ if __name__ == "__main__":
hparams = dict( hparams = dict(
shape=(18, 32), shape=(18, 32),
alpha=1.0, alpha=1.0,
sigma=8, sigma=16,
lr=0.1, lr=0.1,
) )
@ -103,7 +107,7 @@ if __name__ == "__main__":
# Setup trainer # Setup trainer
trainer = pl.Trainer.from_argparse_args( trainer = pl.Trainer.from_argparse_args(
args, args,
max_epochs=300, max_epochs=500,
callbacks=[vis], callbacks=[vis],
weights_summary="full", weights_summary="full",
) )