feat: data argument optional in some visualizers
				
					
				
			This commit is contained in:
		@@ -14,7 +14,7 @@ from ..utils.utils import mesh2d
 | 
				
			|||||||
class Vis2DAbstract(pl.Callback):
 | 
					class Vis2DAbstract(pl.Callback):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
                 data,
 | 
					                 data=None,
 | 
				
			||||||
                 title="Prototype Visualization",
 | 
					                 title="Prototype Visualization",
 | 
				
			||||||
                 cmap="viridis",
 | 
					                 cmap="viridis",
 | 
				
			||||||
                 xlabel="Data dimension 1",
 | 
					                 xlabel="Data dimension 1",
 | 
				
			||||||
@@ -32,6 +32,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
                 block=False):
 | 
					                 block=False):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if data:
 | 
				
			||||||
            if isinstance(data, Dataset):
 | 
					            if isinstance(data, Dataset):
 | 
				
			||||||
                x, y = next(iter(DataLoader(data, batch_size=len(data))))
 | 
					                x, y = next(iter(DataLoader(data, batch_size=len(data))))
 | 
				
			||||||
            elif isinstance(data, torch.utils.data.DataLoader):
 | 
					            elif isinstance(data, torch.utils.data.DataLoader):
 | 
				
			||||||
@@ -48,6 +49,9 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            self.x_train = x
 | 
					            self.x_train = x
 | 
				
			||||||
            self.y_train = y
 | 
					            self.y_train = y
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.x_train = None
 | 
				
			||||||
 | 
					            self.y_train = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.title = title
 | 
					        self.title = title
 | 
				
			||||||
        self.xlabel = xlabel
 | 
					        self.xlabel = xlabel
 | 
				
			||||||
@@ -136,10 +140,13 @@ class VisGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        plabels = pl_module.prototype_labels
 | 
					        plabels = pl_module.prototype_labels
 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
        ax = self.setup_ax()
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
        self.plot_data(ax, x_train, y_train)
 | 
					 | 
				
			||||||
        self.plot_protos(ax, protos, plabels)
 | 
					        self.plot_protos(ax, protos, plabels)
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        if x_train is not None:
 | 
				
			||||||
        mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
 | 
					            self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
 | 
					            mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
 | 
				
			||||||
 | 
					                                        self.border, self.resolution)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
 | 
				
			||||||
        _components = pl_module.proto_layer._components
 | 
					        _components = pl_module.proto_layer._components
 | 
				
			||||||
        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
 | 
					        mesh_input = torch.from_numpy(mesh_input).type_as(_components)
 | 
				
			||||||
        y_pred = pl_module.predict(mesh_input)
 | 
					        y_pred = pl_module.predict(mesh_input)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user