From 197b728c6322a17717260d47afd4eecfdd00e918 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 2 Feb 2022 21:45:44 +0100 Subject: [PATCH] feat: add `visualize` method to visualization callbacks All visualization callbacks now contain a `visualize` method that takes an appropriate PyTorchLightning Module and visualizes it without the need for a Trainer. This is to encourage users to perform one-off visualizations after training. --- prototorch/models/vis.py | 48 ++++++++++------------------------------ 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 49724a0..d57fc4f 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -114,16 +114,19 @@ class Vis2DAbstract(pl.Callback): else: plt.show(block=self.block) + def on_epoch_end(self, trainer, pl_module): + if not self.precheck(trainer): + return True + self.visualize(pl_module) + self.log_and_display(trainer, pl_module) + def on_train_end(self, trainer, pl_module): plt.close() class VisGLVQ2D(Vis2DAbstract): - def on_epoch_end(self, trainer, pl_module): - if not self.precheck(trainer): - return True - + def visualize(self, pl_module): protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train @@ -139,8 +142,6 @@ class VisGLVQ2D(Vis2DAbstract): y_pred = y_pred.cpu().reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - self.log_and_display(trainer, pl_module) - class VisSiameseGLVQ2D(Vis2DAbstract): @@ -148,10 +149,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract): super().__init__(*args, **kwargs) self.map_protos = map_protos - def on_epoch_end(self, trainer, pl_module): - if not self.precheck(trainer): - return True - + def visualize(self, pl_module): protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train @@ -178,8 +176,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract): y_pred = y_pred.cpu().reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - self.log_and_display(trainer, pl_module) - class VisGMLVQ2D(Vis2DAbstract): @@ -187,10 +183,7 @@ class VisGMLVQ2D(Vis2DAbstract): super().__init__(*args, **kwargs) self.ev_proj = ev_proj - def on_epoch_end(self, trainer, pl_module): - if not self.precheck(trainer): - return True - + def visualize(self, pl_module): protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train @@ -212,15 +205,10 @@ class VisGMLVQ2D(Vis2DAbstract): if self.show_protos: self.plot_protos(ax, protos, plabels) - self.log_and_display(trainer, pl_module) - class VisCBC2D(Vis2DAbstract): - def on_epoch_end(self, trainer, pl_module): - if not self.precheck(trainer): - return True - + def visualize(self, pl_module): x_train, y_train = self.x_train, self.y_train protos = pl_module.components ax = self.setup_ax(xlabel="Data dimension 1", @@ -236,15 +224,10 @@ class VisCBC2D(Vis2DAbstract): ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - self.log_and_display(trainer, pl_module) - class VisNG2D(Vis2DAbstract): - def on_epoch_end(self, trainer, pl_module): - if not self.precheck(trainer): - return True - + def visualize(self, pl_module): x_train, y_train = self.x_train, self.y_train protos = pl_module.prototypes cmat = pl_module.topology_layer.cmat.cpu().numpy() @@ -264,8 +247,6 @@ class VisNG2D(Vis2DAbstract): "k-", ) - self.log_and_display(trainer, pl_module) - class VisImgComp(Vis2DAbstract): @@ -321,14 +302,9 @@ class VisImgComp(Vis2DAbstract): dataformats=self.dataformats, ) - def on_epoch_end(self, trainer, pl_module): - if not self.precheck(trainer): - return True - + def visualize(self, pl_module): if self.show: components = pl_module.components grid = torchvision.utils.make_grid(components, nrow=self.num_columns) plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap) - - self.log_and_display(trainer, pl_module)