From ff7a1e93d21997cf94adb6e4545d3a27b9564249 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Sun, 9 May 2021 20:53:03 +0200 Subject: [PATCH] Refactor visualization callbacks --- prototorch/models/vis.py | 111 ++++++++++++++++++++------------------- 1 file changed, 57 insertions(+), 54 deletions(-) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 6099bc3..7f21e48 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -269,6 +269,7 @@ class Vis2DAbstract(pl.Callback): cmap="viridis", border=1, resolution=50, + show_protos=True, tensorboard=False, show_last_only=False, pause_time=0.1, @@ -288,11 +289,17 @@ class Vis2DAbstract(pl.Callback): self.cmap = cmap self.border = border self.resolution = resolution + self.show_protos = show_protos self.tensorboard = tensorboard self.show_last_only = show_last_only self.pause_time = pause_time self.block = block + def precheck(self, trainer): + if self.show_last_only: + if trainer.current_epoch != trainer.max_epochs - 1: + return + def setup_ax(self, xlabel=None, ylabel=None): ax = self.fig.gca() ax.cla() @@ -312,6 +319,28 @@ class Vis2DAbstract(pl.Callback): mesh_input = np.c_[xx.ravel(), yy.ravel()] return mesh_input, xx, yy + def plot_data(self, ax, x, y): + ax.scatter( + x[:, 0], + x[:, 1], + c=y, + cmap=self.cmap, + edgecolor="k", + marker="o", + s=30, + ) + + def plot_protos(self, ax, protos, plabels): + ax.scatter( + protos[:, 0], + protos[:, 1], + c=plabels, + cmap=self.cmap, + edgecolor="k", + marker="D", + s=50, + ) + def add_to_tensorboard(self, trainer, pl_module): tb = pl_module.logger.experiment tb.add_figure(tag=f"{self.title}", @@ -330,115 +359,89 @@ class Vis2DAbstract(pl.Callback): class VisGLVQ2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): - if self.show_last_only: - if trainer.current_epoch != trainer.max_epochs - 1: - return + self.precheck(trainer) + protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train ax = self.setup_ax(xlabel="Data dimension 1", ylabel="Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=plabels, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) + self.plot_data(ax, x_train, y_train) + self.plot_protos(ax, protos, plabels) x = np.vstack((x_train, protos)) mesh_input, xx, yy = self.get_mesh_input(x) y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = y_pred.reshape(xx.shape) - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - # ax.set_xlim(left=x_min + 0, right=x_max - 0) - # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) self.log_and_display(trainer, pl_module) class VisSiameseGLVQ2D(Vis2DAbstract): + def __init__(self, *args, map_protos=True, **kwargs): + super().__init__(*args, **kwargs) + self.map_protos = map_protos + def on_epoch_end(self, trainer, pl_module): + self.precheck(trainer) + protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train x_train = pl_module.backbone(torch.Tensor(x_train)).detach() - protos = pl_module.backbone(torch.Tensor(protos)).detach() + if self.map_protos: + protos = pl_module.backbone(torch.Tensor(protos)).detach() ax = self.setup_ax() - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c=plabels, - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) - x = np.vstack((x_train, protos)) - mesh_input, xx, yy = self.get_mesh_input(x) + self.plot_data(ax, x_train, y_train) + if self.show_protos: + self.plot_protos(ax, protos, plabels) + x = np.vstack((x_train, protos)) + mesh_input, xx, yy = self.get_mesh_input(x) + else: + mesh_input, xx, yy = self.get_mesh_input(x_train) y_pred = pl_module.predict_latent(torch.Tensor(mesh_input)) y_pred = y_pred.reshape(xx.shape) - ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - # ax.set_xlim(left=x_min + 0, right=x_max - 0) - # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) self.log_and_display(trainer, pl_module) class VisCBC2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): + self.precheck(trainer) + x_train, y_train = self.x_train, self.y_train protos = pl_module.components ax = self.setup_ax(xlabel="Data dimension 1", ylabel="Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c="w", - cmap=self.cmap, - edgecolor="k", - marker="D", - s=50, - ) + self.plot_data(ax, x_train, y_train) + self.plot_protos(ax, protos, plabels) x = np.vstack((x_train, protos)) mesh_input, xx, yy = self.get_mesh_input(x) y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = y_pred.reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - # ax.set_xlim(left=x_min + 0, right=x_max - 0) - # ax.set_ylim(bottom=y_min + 0, top=y_max - 0) self.log_and_display(trainer, pl_module) class VisNG2D(Vis2DAbstract): def on_epoch_end(self, trainer, pl_module): + self.precheck(trainer) + x_train, y_train = self.x_train, self.y_train protos = pl_module.prototypes cmat = pl_module.topology_layer.cmat.cpu().numpy() ax = self.setup_ax(xlabel="Data dimension 1", ylabel="Data dimension 2") - ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") - ax.scatter( - protos[:, 0], - protos[:, 1], - c="k", - edgecolor="k", - marker="D", - s=50, - ) + self.plot_data(ax, x_train, y_train) + self.plot_protos(ax, protos, "w") # Draw connections for i in range(len(protos)): - for j in range(len(protos)): + for j in range(i, len(protos)): if cmat[i][j]: ax.plot( [protos[i, 0], protos[j, 0]],