Move and improve visualization callbacks
This commit is contained in:
		@@ -9,6 +9,7 @@ from prototorch.utils.celluloid import Camera
 | 
				
			|||||||
from prototorch.utils.colors import color_scheme
 | 
					from prototorch.utils.colors import color_scheme
 | 
				
			||||||
from prototorch.utils.utils import (gif_from_dir, make_directory,
 | 
					from prototorch.utils.utils import (gif_from_dir, make_directory,
 | 
				
			||||||
                                    prettify_string)
 | 
					                                    prettify_string)
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader, Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisWeights(pl.Callback):
 | 
					class VisWeights(pl.Callback):
 | 
				
			||||||
@@ -263,25 +264,54 @@ class VisPointProtos(VisWeights):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class Vis2DAbstract(pl.Callback):
 | 
					class Vis2DAbstract(pl.Callback):
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
                 x_train,
 | 
					                 data,
 | 
				
			||||||
                 y_train,
 | 
					 | 
				
			||||||
                 title="Prototype Visualization",
 | 
					                 title="Prototype Visualization",
 | 
				
			||||||
                 cmap="viridis",
 | 
					                 cmap="viridis",
 | 
				
			||||||
                 border=1,
 | 
					                 border=1,
 | 
				
			||||||
 | 
					                 resolution=50,
 | 
				
			||||||
                 tensorboard=False,
 | 
					                 tensorboard=False,
 | 
				
			||||||
                 show_last_only=False,
 | 
					                 show_last_only=False,
 | 
				
			||||||
 | 
					                 pause_time=0.1,
 | 
				
			||||||
                 block=False):
 | 
					                 block=False):
 | 
				
			||||||
        super().__init__()
 | 
					        super().__init__()
 | 
				
			||||||
        self.x_train = x_train
 | 
					
 | 
				
			||||||
        self.y_train = y_train
 | 
					        if isinstance(data, Dataset):
 | 
				
			||||||
 | 
					            x, y = next(iter(DataLoader(data, batch_size=len(data))))
 | 
				
			||||||
 | 
					            x = x.view(len(data), -1)  # flatten
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            x, y = data
 | 
				
			||||||
 | 
					        self.x_train = x
 | 
				
			||||||
 | 
					        self.y_train = y
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.title = title
 | 
					        self.title = title
 | 
				
			||||||
        self.fig = plt.figure(self.title)
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
        self.cmap = cmap
 | 
					        self.cmap = cmap
 | 
				
			||||||
        self.border = border
 | 
					        self.border = border
 | 
				
			||||||
 | 
					        self.resolution = resolution
 | 
				
			||||||
        self.tensorboard = tensorboard
 | 
					        self.tensorboard = tensorboard
 | 
				
			||||||
        self.show_last_only = show_last_only
 | 
					        self.show_last_only = show_last_only
 | 
				
			||||||
 | 
					        self.pause_time = pause_time
 | 
				
			||||||
        self.block = block
 | 
					        self.block = block
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def setup_ax(self, xlabel=None, ylabel=None):
 | 
				
			||||||
 | 
					        ax = self.fig.gca()
 | 
				
			||||||
 | 
					        ax.cla()
 | 
				
			||||||
 | 
					        ax.set_title(self.title)
 | 
				
			||||||
 | 
					        ax.axis("off")
 | 
				
			||||||
 | 
					        if xlabel:
 | 
				
			||||||
 | 
					            ax.set_xlabel("Data dimension 1")
 | 
				
			||||||
 | 
					        if ylabel:
 | 
				
			||||||
 | 
					            ax.set_ylabel("Data dimension 2")
 | 
				
			||||||
 | 
					        return ax
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_mesh_input(self, x):
 | 
				
			||||||
 | 
					        x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
 | 
				
			||||||
 | 
					        y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
 | 
				
			||||||
 | 
					        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution),
 | 
				
			||||||
 | 
					                             np.arange(y_min, y_max, 1 / self.resolution))
 | 
				
			||||||
 | 
					        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
				
			||||||
 | 
					        return mesh_input, xx, yy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_to_tensorboard(self, trainer, pl_module):
 | 
					    def add_to_tensorboard(self, trainer, pl_module):
 | 
				
			||||||
        tb = pl_module.logger.experiment
 | 
					        tb = pl_module.logger.experiment
 | 
				
			||||||
        tb.add_figure(tag=f"{self.title}",
 | 
					        tb.add_figure(tag=f"{self.title}",
 | 
				
			||||||
@@ -289,6 +319,14 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
                      global_step=trainer.current_epoch,
 | 
					                      global_step=trainer.current_epoch,
 | 
				
			||||||
                      close=False)
 | 
					                      close=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def log_and_display(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        if self.tensorboard:
 | 
				
			||||||
 | 
					            self.add_to_tensorboard(trainer, pl_module)
 | 
				
			||||||
 | 
					        if not self.block:
 | 
				
			||||||
 | 
					            plt.pause(self.pause_time)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            plt.show(block=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisGLVQ2D(Vis2DAbstract):
 | 
					class VisGLVQ2D(Vis2DAbstract):
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
@@ -298,12 +336,8 @@ class VisGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        protos = pl_module.prototypes
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
        plabels = pl_module.prototype_labels
 | 
					        plabels = pl_module.prototype_labels
 | 
				
			||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
        ax = self.fig.gca()
 | 
					        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
				
			||||||
        ax.cla()
 | 
					                           ylabel="Data dimension 2")
 | 
				
			||||||
        ax.set_title(self.title)
 | 
					 | 
				
			||||||
        ax.axis("off")
 | 
					 | 
				
			||||||
        ax.set_xlabel("Data dimension 1")
 | 
					 | 
				
			||||||
        ax.set_ylabel("Data dimension 2")
 | 
					 | 
				
			||||||
        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
					        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
				
			||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            protos[:, 0],
 | 
					            protos[:, 0],
 | 
				
			||||||
@@ -315,23 +349,15 @@ class VisGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
            s=50,
 | 
					            s=50,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
        x_min, x_max = x[:, 0].min() - 1, x[:, 0].max() + 1
 | 
					        mesh_input, xx, yy = self.get_mesh_input(x)
 | 
				
			||||||
        y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
 | 
					 | 
				
			||||||
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
 | 
					 | 
				
			||||||
                             np.arange(y_min, y_max, 1 / 50))
 | 
					 | 
				
			||||||
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
					 | 
				
			||||||
        y_pred = pl_module.predict(torch.Tensor(mesh_input))
 | 
					        y_pred = pl_module.predict(torch.Tensor(mesh_input))
 | 
				
			||||||
        y_pred = y_pred.reshape(xx.shape)
 | 
					        y_pred = y_pred.reshape(xx.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
        ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
					        # ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
				
			||||||
        ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
					        # ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
				
			||||||
        if self.tensorboard:
 | 
					
 | 
				
			||||||
            self.add_to_tensorboard(trainer, pl_module)
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
        if not self.block:
 | 
					 | 
				
			||||||
            plt.pause(0.01)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            plt.show(block=True)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
					class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			||||||
@@ -341,10 +367,7 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
        x_train, y_train = self.x_train, self.y_train
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
        x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
 | 
					        x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
 | 
				
			||||||
        protos = pl_module.backbone(torch.Tensor(protos)).detach()
 | 
					        protos = pl_module.backbone(torch.Tensor(protos)).detach()
 | 
				
			||||||
        ax = self.fig.gca()
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
        ax.cla()
 | 
					 | 
				
			||||||
        ax.set_title(self.title)
 | 
					 | 
				
			||||||
        ax.axis("off")
 | 
					 | 
				
			||||||
        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
					        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
				
			||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            protos[:, 0],
 | 
					            protos[:, 0],
 | 
				
			||||||
@@ -356,48 +379,54 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
            s=50,
 | 
					            s=50,
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
        x = np.vstack((x_train, protos))
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
        x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border
 | 
					        mesh_input, xx, yy = self.get_mesh_input(x)
 | 
				
			||||||
        y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border
 | 
					 | 
				
			||||||
        xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / 50),
 | 
					 | 
				
			||||||
                             np.arange(y_min, y_max, 1 / 50))
 | 
					 | 
				
			||||||
        mesh_input = np.c_[xx.ravel(), yy.ravel()]
 | 
					 | 
				
			||||||
        y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
 | 
					        y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
 | 
				
			||||||
        y_pred = y_pred.reshape(xx.shape)
 | 
					        y_pred = y_pred.reshape(xx.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
        ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
					        # ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
				
			||||||
        ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
					        # ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
				
			||||||
        tb = pl_module.logger.experiment
 | 
					 | 
				
			||||||
        tb.add_figure(
 | 
					 | 
				
			||||||
            tag=f"{self.title}",
 | 
					 | 
				
			||||||
            figure=self.fig,
 | 
					 | 
				
			||||||
            global_step=trainer.current_epoch,
 | 
					 | 
				
			||||||
            close=False,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.tensorboard:
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
            self.add_to_tensorboard(trainer, pl_module)
 | 
					
 | 
				
			||||||
        if not self.block:
 | 
					
 | 
				
			||||||
            plt.pause(0.05)
 | 
					class VisCBC2D(Vis2DAbstract):
 | 
				
			||||||
        else:
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
            plt.show(block=True)
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
 | 
					        protos = pl_module.components
 | 
				
			||||||
 | 
					        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
				
			||||||
 | 
					                           ylabel="Data dimension 2")
 | 
				
			||||||
 | 
					        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
				
			||||||
 | 
					        ax.scatter(
 | 
				
			||||||
 | 
					            protos[:, 0],
 | 
				
			||||||
 | 
					            protos[:, 1],
 | 
				
			||||||
 | 
					            c="w",
 | 
				
			||||||
 | 
					            cmap=self.cmap,
 | 
				
			||||||
 | 
					            edgecolor="k",
 | 
				
			||||||
 | 
					            marker="D",
 | 
				
			||||||
 | 
					            s=50,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        x = np.vstack((x_train, protos))
 | 
				
			||||||
 | 
					        mesh_input, xx, yy = self.get_mesh_input(x)
 | 
				
			||||||
 | 
					        y_pred = pl_module.predict(torch.Tensor(mesh_input))
 | 
				
			||||||
 | 
					        y_pred = y_pred.reshape(xx.shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
 | 
					        # ax.set_xlim(left=x_min + 0, right=x_max - 0)
 | 
				
			||||||
 | 
					        # ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisNG2D(Vis2DAbstract):
 | 
					class VisNG2D(Vis2DAbstract):
 | 
				
			||||||
    def on_epoch_end(self, trainer, pl_module):
 | 
					    def on_epoch_end(self, trainer, pl_module):
 | 
				
			||||||
 | 
					        x_train, y_train = self.x_train, self.y_train
 | 
				
			||||||
        protos = pl_module.prototypes
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
        cmat = pl_module.topology_layer.cmat.cpu().numpy()
 | 
					        cmat = pl_module.topology_layer.cmat.cpu().numpy()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Visualize the data and the prototypes
 | 
					        ax = self.setup_ax(xlabel="Data dimension 1",
 | 
				
			||||||
        ax = self.fig.gca()
 | 
					                           ylabel="Data dimension 2")
 | 
				
			||||||
        ax.cla()
 | 
					        ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k")
 | 
				
			||||||
        ax.set_title(self.title)
 | 
					 | 
				
			||||||
        ax.set_xlabel("Data dimension 1")
 | 
					 | 
				
			||||||
        ax.set_ylabel("Data dimension 2")
 | 
					 | 
				
			||||||
        ax.scatter(self.x_train[:, 0],
 | 
					 | 
				
			||||||
                   self.x_train[:, 1],
 | 
					 | 
				
			||||||
                   c=self.y_train,
 | 
					 | 
				
			||||||
                   edgecolor="k")
 | 
					 | 
				
			||||||
        ax.scatter(
 | 
					        ax.scatter(
 | 
				
			||||||
            protos[:, 0],
 | 
					            protos[:, 0],
 | 
				
			||||||
            protos[:, 1],
 | 
					            protos[:, 1],
 | 
				
			||||||
@@ -417,9 +446,4 @@ class VisNG2D(Vis2DAbstract):
 | 
				
			|||||||
                        "k-",
 | 
					                        "k-",
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.tensorboard:
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
            self.add_to_tensorboard(trainer, pl_module)
 | 
					 | 
				
			||||||
        if not self.block:
 | 
					 | 
				
			||||||
            plt.pause(0.01)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            plt.show(block=True)
 | 
					 | 
				
			||||||
		Reference in New Issue
	
	Block a user