Add tensorboard argument to visualization callbacks
This commit is contained in:
		@@ -268,6 +268,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
                 title="Prototype Visualization",
 | 
					                 title="Prototype Visualization",
 | 
				
			||||||
                 cmap="viridis",
 | 
					                 cmap="viridis",
 | 
				
			||||||
                 border=1,
 | 
					                 border=1,
 | 
				
			||||||
 | 
					                 tensorboard=False,
 | 
				
			||||||
                 show_last_only=False,
 | 
					                 show_last_only=False,
 | 
				
			||||||
                 block=False):
 | 
					                 block=False):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
@@ -277,9 +278,17 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        self.fig = plt.figure(self.title)
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
        self.cmap = cmap
 | 
					        self.cmap = cmap
 | 
				
			||||||
        self.border = border
 | 
					        self.border = border
 | 
				
			||||||
 | 
					        self.tensorboard = tensorboard
 | 
				
			||||||
        self.show_last_only = show_last_only
 | 
					        self.show_last_only = show_last_only
 | 
				
			||||||
        self.block = block
 | 
					        self.block = block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_to_tensorboard(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        tb = pl_module.logger.experiment
 | 
				
			||||||
 | 
					        tb.add_figure(tag=f"{self.title}",
 | 
				
			||||||
 | 
					                      figure=self.fig,
 | 
				
			||||||
 | 
					                      global_step=trainer.current_epoch,
 | 
				
			||||||
 | 
					                      close=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisGLVQ2D(Vis2DAbstract):
 | 
					class VisGLVQ2D(Vis2DAbstract):
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
@@ -317,6 +326,8 @@ class VisGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
        ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
					        ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
				
			||||||
        ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
					        ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
				
			||||||
 | 
					        if self.tensorboard:
 | 
				
			||||||
 | 
					            self.add_to_tensorboard(trainer, pl_module)
 | 
				
			||||||
        if not self.block:
 | 
					        if not self.block:
 | 
				
			||||||
            plt.pause(0.01)
 | 
					            plt.pause(0.01)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
@@ -364,6 +375,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
            close=False,
 | 
					            close=False,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.tensorboard:
 | 
				
			||||||
 | 
					            self.add_to_tensorboard(trainer, pl_module)
 | 
				
			||||||
        if not self.block:
 | 
					        if not self.block:
 | 
				
			||||||
            plt.pause(0.01)
 | 
					            plt.pause(0.01)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
@@ -404,6 +417,8 @@ class VisNG2D(Vis2DAbstract):
 | 
				
			|||||||
                        "k-",
 | 
					                        "k-",
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.tensorboard:
 | 
				
			||||||
 | 
					            self.add_to_tensorboard(trainer, pl_module)
 | 
				
			||||||
        if not self.block:
 | 
					        if not self.block:
 | 
				
			||||||
            plt.pause(0.01)
 | 
					            plt.pause(0.01)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user