feat: add visualize method to visualization callbacks
				
					
				
			All visualization callbacks now contain a `visualize` method that takes an appropriate PyTorchLightning Module and visualizes it without the need for a Trainer. This is to encourage users to perform one-off visualizations after training.
This commit is contained in:
		| @@ -114,16 +114,19 @@ class Vis2DAbstract(pl.Callback): | ||||
|             else: | ||||
|                 plt.show(block=self.block) | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if not self.precheck(trainer): | ||||
|             return True | ||||
|         self.visualize(pl_module) | ||||
|         self.log_and_display(trainer, pl_module) | ||||
|  | ||||
|     def on_train_end(self, trainer, pl_module): | ||||
|         plt.close() | ||||
|  | ||||
|  | ||||
| class VisGLVQ2D(Vis2DAbstract): | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if not self.precheck(trainer): | ||||
|             return True | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         protos = pl_module.prototypes | ||||
|         plabels = pl_module.prototype_labels | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
| @@ -139,8 +142,6 @@ class VisGLVQ2D(Vis2DAbstract): | ||||
|         y_pred = y_pred.cpu().reshape(xx.shape) | ||||
|         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) | ||||
|  | ||||
|         self.log_and_display(trainer, pl_module) | ||||
|  | ||||
|  | ||||
| class VisSiameseGLVQ2D(Vis2DAbstract): | ||||
|  | ||||
| @@ -148,10 +149,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.map_protos = map_protos | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if not self.precheck(trainer): | ||||
|             return True | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         protos = pl_module.prototypes | ||||
|         plabels = pl_module.prototype_labels | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
| @@ -178,8 +176,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract): | ||||
|         y_pred = y_pred.cpu().reshape(xx.shape) | ||||
|         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) | ||||
|  | ||||
|         self.log_and_display(trainer, pl_module) | ||||
|  | ||||
|  | ||||
| class VisGMLVQ2D(Vis2DAbstract): | ||||
|  | ||||
| @@ -187,10 +183,7 @@ class VisGMLVQ2D(Vis2DAbstract): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.ev_proj = ev_proj | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if not self.precheck(trainer): | ||||
|             return True | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         protos = pl_module.prototypes | ||||
|         plabels = pl_module.prototype_labels | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
| @@ -212,15 +205,10 @@ class VisGMLVQ2D(Vis2DAbstract): | ||||
|         if self.show_protos: | ||||
|             self.plot_protos(ax, protos, plabels) | ||||
|  | ||||
|         self.log_and_display(trainer, pl_module) | ||||
|  | ||||
|  | ||||
| class VisCBC2D(Vis2DAbstract): | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if not self.precheck(trainer): | ||||
|             return True | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
|         protos = pl_module.components | ||||
|         ax = self.setup_ax(xlabel="Data dimension 1", | ||||
| @@ -236,15 +224,10 @@ class VisCBC2D(Vis2DAbstract): | ||||
|  | ||||
|         ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) | ||||
|  | ||||
|         self.log_and_display(trainer, pl_module) | ||||
|  | ||||
|  | ||||
| class VisNG2D(Vis2DAbstract): | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if not self.precheck(trainer): | ||||
|             return True | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         x_train, y_train = self.x_train, self.y_train | ||||
|         protos = pl_module.prototypes | ||||
|         cmat = pl_module.topology_layer.cmat.cpu().numpy() | ||||
| @@ -264,8 +247,6 @@ class VisNG2D(Vis2DAbstract): | ||||
|                         "k-", | ||||
|                     ) | ||||
|  | ||||
|         self.log_and_display(trainer, pl_module) | ||||
|  | ||||
|  | ||||
| class VisImgComp(Vis2DAbstract): | ||||
|  | ||||
| @@ -321,14 +302,9 @@ class VisImgComp(Vis2DAbstract): | ||||
|             dataformats=self.dataformats, | ||||
|         ) | ||||
|  | ||||
|     def on_epoch_end(self, trainer, pl_module): | ||||
|         if not self.precheck(trainer): | ||||
|             return True | ||||
|  | ||||
|     def visualize(self, pl_module): | ||||
|         if self.show: | ||||
|             components = pl_module.components | ||||
|             grid = torchvision.utils.make_grid(components, | ||||
|                                                nrow=self.num_columns) | ||||
|             plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap) | ||||
|  | ||||
|         self.log_and_display(trainer, pl_module) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user