Update SOM example
This commit is contained in:
parent
4ca846997a
commit
ea33196a50
@ -37,6 +37,7 @@ class Vis2DColorSOM(pl.Callback):
|
|||||||
h, w = pl_module._grid.shape[:2]
|
h, w = pl_module._grid.shape[:2]
|
||||||
protos = pl_module.prototypes.view(h, w, 3)
|
protos = pl_module.prototypes.view(h, w, 3)
|
||||||
ax.imshow(protos)
|
ax.imshow(protos)
|
||||||
|
ax.axis("off")
|
||||||
|
|
||||||
# Overlay color names
|
# Overlay color names
|
||||||
d = pl_module.compute_distances(self.data)
|
d = pl_module.compute_distances(self.data)
|
||||||
@ -49,7 +50,10 @@ class Vis2DColorSOM(pl.Callback):
|
|||||||
va="center",
|
va="center",
|
||||||
bbox=dict(facecolor="white", alpha=0.5, lw=0))
|
bbox=dict(facecolor="white", alpha=0.5, lw=0))
|
||||||
|
|
||||||
plt.pause(self.pause_time)
|
if trainer.current_epoch != trainer.max_epochs - 1:
|
||||||
|
plt.pause(self.pause_time)
|
||||||
|
else:
|
||||||
|
plt.show(block=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -65,12 +69,12 @@ if __name__ == "__main__":
|
|||||||
hex_colors = [
|
hex_colors = [
|
||||||
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
|
"#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
|
||||||
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
|
"#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
|
||||||
"#545454", "#7f7f7f", "#a8a8a8"
|
"#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
|
||||||
]
|
]
|
||||||
cnames = [
|
cnames = [
|
||||||
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
|
"black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
|
||||||
"red", "cyan", "violet", "yellow", "white", "darkgrey", "mediumgrey",
|
"red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
|
||||||
"lightgrey"
|
"lightgrey", "olive", "purple", "orange"
|
||||||
]
|
]
|
||||||
colors = list(hex_to_rgb(hex_colors))
|
colors = list(hex_to_rgb(hex_colors))
|
||||||
data = torch.Tensor(colors) / 255.0
|
data = torch.Tensor(colors) / 255.0
|
||||||
@ -81,7 +85,7 @@ if __name__ == "__main__":
|
|||||||
hparams = dict(
|
hparams = dict(
|
||||||
shape=(18, 32),
|
shape=(18, 32),
|
||||||
alpha=1.0,
|
alpha=1.0,
|
||||||
sigma=8,
|
sigma=16,
|
||||||
lr=0.1,
|
lr=0.1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -103,7 +107,7 @@ if __name__ == "__main__":
|
|||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer.from_argparse_args(
|
trainer = pl.Trainer.from_argparse_args(
|
||||||
args,
|
args,
|
||||||
max_epochs=300,
|
max_epochs=500,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
weights_summary="full",
|
weights_summary="full",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user