Add border argument in visualization callback
This commit is contained in:
		@@ -267,6 +267,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
                 y_train,
 | 
					                 y_train,
 | 
				
			||||||
                 title="Prototype Visualization",
 | 
					                 title="Prototype Visualization",
 | 
				
			||||||
                 cmap="viridis",
 | 
					                 cmap="viridis",
 | 
				
			||||||
 | 
					                 border=1,
 | 
				
			||||||
                 show_last_only=False,
 | 
					                 show_last_only=False,
 | 
				
			||||||
                 block=False):
 | 
					                 block=False):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
@@ -275,6 +276,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        self.title = title
 | 
					        self.title = title
 | 
				
			||||||
        self.fig = plt.figure(self.title)
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
        self.cmap = cmap
 | 
					        self.cmap = cmap
 | 
				
			||||||
 | 
					        self.border = border
 | 
				
			||||||
        self.show_last_only = show_last_only
 | 
					        self.show_last_only = show_last_only
 | 
				
			||||||
        self.block = block
 | 
					        self.block = block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -343,8 +345,8 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
            s=50,
 | 
					            s=50,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
        x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
 | 
					        x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
 | 
				
			||||||
        y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
 | 
					        y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
 | 
				
			||||||
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
 | 
					        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
 | 
				
			||||||
                             np.arange(y_min, y_max, 1 / 50))
 | 
					                             np.arange(y_min, y_max, 1 / 50))
 | 
				
			||||||
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
					        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user