feat: add visualize
method to visualization callbacks
All visualization callbacks now contain a `visualize` method that takes an appropriate PyTorchLightning Module and visualizes it without the need for a Trainer. This is to encourage users to perform one-off visualizations after training.
This commit is contained in:
parent
98892afee0
commit
197b728c63
@ -114,16 +114,19 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
else:
|
else:
|
||||||
plt.show(block=self.block)
|
plt.show(block=self.block)
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
|
if not self.precheck(trainer):
|
||||||
|
return True
|
||||||
|
self.visualize(pl_module)
|
||||||
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
def on_train_end(self, trainer, pl_module):
|
def on_train_end(self, trainer, pl_module):
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
class VisGLVQ2D(Vis2DAbstract):
|
class VisGLVQ2D(Vis2DAbstract):
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def visualize(self, pl_module):
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
@ -139,8 +142,6 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
|
||||||
|
|
||||||
class VisSiameseGLVQ2D(Vis2DAbstract):
|
class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||||
|
|
||||||
@ -148,10 +149,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.map_protos = map_protos
|
self.map_protos = map_protos
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def visualize(self, pl_module):
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
@ -178,8 +176,6 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
|
||||||
|
|
||||||
class VisGMLVQ2D(Vis2DAbstract):
|
class VisGMLVQ2D(Vis2DAbstract):
|
||||||
|
|
||||||
@ -187,10 +183,7 @@ class VisGMLVQ2D(Vis2DAbstract):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.ev_proj = ev_proj
|
self.ev_proj = ev_proj
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def visualize(self, pl_module):
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
|
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
@ -212,15 +205,10 @@ class VisGMLVQ2D(Vis2DAbstract):
|
|||||||
if self.show_protos:
|
if self.show_protos:
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
|
||||||
|
|
||||||
class VisCBC2D(Vis2DAbstract):
|
class VisCBC2D(Vis2DAbstract):
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def visualize(self, pl_module):
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.components
|
protos = pl_module.components
|
||||||
ax = self.setup_ax(xlabel="Data dimension 1",
|
ax = self.setup_ax(xlabel="Data dimension 1",
|
||||||
@ -236,15 +224,10 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
|
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
|
||||||
|
|
||||||
class VisNG2D(Vis2DAbstract):
|
class VisNG2D(Vis2DAbstract):
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def visualize(self, pl_module):
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
|
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
cmat = pl_module.topology_layer.cmat.cpu().numpy()
|
||||||
@ -264,8 +247,6 @@ class VisNG2D(Vis2DAbstract):
|
|||||||
"k-",
|
"k-",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
|
||||||
|
|
||||||
class VisImgComp(Vis2DAbstract):
|
class VisImgComp(Vis2DAbstract):
|
||||||
|
|
||||||
@ -321,14 +302,9 @@ class VisImgComp(Vis2DAbstract):
|
|||||||
dataformats=self.dataformats,
|
dataformats=self.dataformats,
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def visualize(self, pl_module):
|
||||||
if not self.precheck(trainer):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if self.show:
|
if self.show:
|
||||||
components = pl_module.components
|
components = pl_module.components
|
||||||
grid = torchvision.utils.make_grid(components,
|
grid = torchvision.utils.make_grid(components,
|
||||||
nrow=self.num_columns)
|
nrow=self.num_columns)
|
||||||
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user