Update SOM example
This commit is contained in:
parent
4ca846997a
commit
ea33196a50
@ -37,6 +37,7 @@ class Vis2DColorSOM(pl.Callback):
|
||||
h, w = pl_module._grid.shape[:2]
|
||||
protos = pl_module.prototypes.view(h, w, 3)
|
||||
ax.imshow(protos)
|
||||
ax.axis("off")
|
||||
|
||||
# Overlay color names
|
||||
d = pl_module.compute_distances(self.data)
|
||||
@ -49,7 +50,10 @@ class Vis2DColorSOM(pl.Callback):
|
||||
va="center",
|
||||
bbox=dict(facecolor="white", alpha=0.5, lw=0))
|
||||
|
||||
plt.pause(self.pause_time)
|
||||
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||
plt.pause(self.pause_time)
|
||||
else:
|
||||
plt.show(block=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@ -65,12 +69,12 @@ if __name__ == "__main__":
|
||||
hex_colors = [
|
||||
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
|
||||
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
|
||||
"#545454", "#7f7f7f", "#a8a8a8"
|
||||
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
|
||||
]
|
||||
cnames = [
|
||||
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
|
||||
"red", "cyan", "violet", "yellow", "white", "darkgrey", "mediumgrey",
|
||||
"lightgrey"
|
||||
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
|
||||
"lightgrey", "olive", "purple", "orange"
|
||||
]
|
||||
colors = list(hex_to_rgb(hex_colors))
|
||||
data = torch.Tensor(colors) / 255.0
|
||||
@ -81,7 +85,7 @@ if __name__ == "__main__":
|
||||
hparams = dict(
|
||||
shape=(18, 32),
|
||||
alpha=1.0,
|
||||
sigma=8,
|
||||
sigma=16,
|
||||
lr=0.1,
|
||||
)
|
||||
|
||||
@ -103,7 +107,7 @@ if __name__ == "__main__":
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
max_epochs=300,
|
||||
max_epochs=500,
|
||||
callbacks=[vis],
|
||||
weights_summary="full",
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user