Update SOM example
This commit is contained in:
		@@ -37,6 +37,7 @@ class Vis2DColorSOM(pl.Callback):
 | 
				
			|||||||
        h, w = pl_module._grid.shape[:2]
 | 
					        h, w = pl_module._grid.shape[:2]
 | 
				
			||||||
        protos = pl_module.prototypes.view(h, w, 3)
 | 
					        protos = pl_module.prototypes.view(h, w, 3)
 | 
				
			||||||
        ax.imshow(protos)
 | 
					        ax.imshow(protos)
 | 
				
			||||||
 | 
					        ax.axis("off")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Overlay color names
 | 
					        # Overlay color names
 | 
				
			||||||
        d = pl_module.compute_distances(self.data)
 | 
					        d = pl_module.compute_distances(self.data)
 | 
				
			||||||
@@ -49,7 +50,10 @@ class Vis2DColorSOM(pl.Callback):
 | 
				
			|||||||
                     va="center",
 | 
					                     va="center",
 | 
				
			||||||
                     bbox=dict(facecolor="white", alpha=0.5, lw=0))
 | 
					                     bbox=dict(facecolor="white", alpha=0.5, lw=0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        plt.pause(self.pause_time)
 | 
					        if trainer.current_epoch != trainer.max_epochs - 1:
 | 
				
			||||||
 | 
					            plt.pause(self.pause_time)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            plt.show(block=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
@@ -65,12 +69,12 @@ if __name__ == "__main__":
 | 
				
			|||||||
    hex_colors = [
 | 
					    hex_colors = [
 | 
				
			||||||
        "#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
 | 
					        "#000000", "#0000ff", "#00007f", "#1f86ff", "#5466aa", "#997fff",
 | 
				
			||||||
        "#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
 | 
					        "#00ff00", "#ff0000", "#00ffff", "#ff00ff", "#ffff00", "#ffffff",
 | 
				
			||||||
        "#545454", "#7f7f7f", "#a8a8a8"
 | 
					        "#545454", "#7f7f7f", "#a8a8a8", "#808000", "#800080", "#ffa500"
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
    cnames = [
 | 
					    cnames = [
 | 
				
			||||||
        "black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
 | 
					        "black", "blue", "darkblue", "skyblue", "greyblue", "lilac", "green",
 | 
				
			||||||
        "red", "cyan", "violet", "yellow", "white", "darkgrey", "mediumgrey",
 | 
					        "red", "cyan", "magenta", "yellow", "white", "darkgrey", "mediumgrey",
 | 
				
			||||||
        "lightgrey"
 | 
					        "lightgrey", "olive", "purple", "orange"
 | 
				
			||||||
    ]
 | 
					    ]
 | 
				
			||||||
    colors = list(hex_to_rgb(hex_colors))
 | 
					    colors = list(hex_to_rgb(hex_colors))
 | 
				
			||||||
    data = torch.Tensor(colors) / 255.0
 | 
					    data = torch.Tensor(colors) / 255.0
 | 
				
			||||||
@@ -81,7 +85,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    hparams = dict(
 | 
					    hparams = dict(
 | 
				
			||||||
        shape=(18, 32),
 | 
					        shape=(18, 32),
 | 
				
			||||||
        alpha=1.0,
 | 
					        alpha=1.0,
 | 
				
			||||||
        sigma=8,
 | 
					        sigma=16,
 | 
				
			||||||
        lr=0.1,
 | 
					        lr=0.1,
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -103,7 +107,7 @@ if __name__ == "__main__":
 | 
				
			|||||||
    # Setup trainer
 | 
					    # Setup trainer
 | 
				
			||||||
    trainer = pl.Trainer.from_argparse_args(
 | 
					    trainer = pl.Trainer.from_argparse_args(
 | 
				
			||||||
        args,
 | 
					        args,
 | 
				
			||||||
        max_epochs=300,
 | 
					        max_epochs=500,
 | 
				
			||||||
        callbacks=[vis],
 | 
					        callbacks=[vis],
 | 
				
			||||||
        weights_summary="full",
 | 
					        weights_summary="full",
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user