From 77b7b59badd96ff89292aeae721fcf1fbc2ef15b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Mon, 17 May 2021 16:59:22 +0200 Subject: [PATCH] Clean visualization callbacks --- prototorch/models/vis.py | 292 ++++----------------------------------- 1 file changed, 25 insertions(+), 267 deletions(-) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 327043f..17105f5 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -6,262 +6,11 @@ import torch import torchvision from matplotlib import pyplot as plt from matplotlib.offsetbox import AnchoredText -from torch.utils.data import DataLoader, Dataset - from prototorch.utils.celluloid import Camera from prototorch.utils.colors import color_scheme from prototorch.utils.utils import (gif_from_dir, make_directory, prettify_string) - - -class VisWeights(pl.Callback): - """Abstract weight visualization callback.""" - def __init__( - self, - data=None, - ignore_last_output_row=False, - label_map=None, - project_mesh=False, - project_protos=False, - voronoi=False, - axis_off=True, - cmap="viridis", - show=True, - display_logs=True, - display_logs_settings={}, - pause_time=0.5, - border=1, - resolution=10, - interval=False, - save=False, - snap=True, - save_dir="./img", - make_gif=False, - make_mp4=False, - verbose=True, - dpi=500, - fps=5, - figsize=(11, 8.5), # standard paper in inches - prefix="", - distance_layer_index=-1, - **kwargs, - ): - super().__init__(**kwargs) - self.data = data - self.ignore_last_output_row = ignore_last_output_row - self.label_map = label_map - self.voronoi = voronoi - self.axis_off = True - self.project_mesh = project_mesh - self.project_protos = project_protos - self.cmap = cmap - self.show = show - self.display_logs = display_logs - self.display_logs_settings = display_logs_settings - self.pause_time = pause_time - self.border = border - self.resolution = resolution - self.interval = interval - self.save = save - self.snap = snap - self.save_dir = save_dir - self.make_gif = make_gif - self.make_mp4 = make_mp4 - self.verbose = verbose - self.dpi = dpi - self.fps = fps - self.figsize = figsize - self.prefix = prefix - self.distance_layer_index = distance_layer_index - self.title = "Weights Visualization" - make_directory(self.save_dir) - - def _skip_epoch(self, epoch): - if self.interval: - if epoch % self.interval != 0: - return True - return False - - def _clean_and_setup_ax(self): - ax = self.ax - if not self.snap: - ax.cla() - ax.set_title(self.title) - if self.axis_off: - ax.axis("off") - - def _savefig(self, fignum, orientation="horizontal"): - figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png" - figsize = self.figsize - if orientation == "vertical": - figsize = figsize[::-1] - elif orientation == "horizontal": - pass - else: - pass - self.fig.set_size_inches(figsize, forward=False) - self.fig.savefig(figname, dpi=self.dpi) - - def _show_and_save(self, epoch): - if self.show: - plt.pause(self.pause_time) - if self.save: - self._savefig(epoch) - if self.snap: - self.camera.snap() - - def _display_logs(self, ax, epoch, logs): - if self.display_logs: - settings = dict( - loc="lower right", - # padding between the text and bounding box - pad=0.5, - # padding between the bounding box and the axes - borderpad=1.0, - # https://matplotlib.org/api/text_api.html#matplotlib.text.Text - prop=dict( - fontfamily="monospace", - fontweight="medium", - fontsize=12, - ), - ) - - # Override settings with self.display_logs_settings. - settings = {**settings, **self.display_logs_settings} - - log_string = f"""Epoch: {epoch:04d}, - val_loss: {logs.get('val_loss', np.nan):.03f}, - val_acc: {logs.get('val_acc', np.nan):.03f}, - loss: {logs.get('loss', np.nan):.03f}, - acc: {logs.get('acc', np.nan):.03f} - """ - log_string = prettify_string(log_string, end="") - # https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText - anchored_text = AnchoredText(log_string, **settings) - self.ax.add_artist(anchored_text) - - def on_train_start(self, trainer, pl_module, logs={}): - self.fig = plt.figure(self.title) - self.fig.set_size_inches(self.figsize, forward=False) - self.ax = self.fig.add_subplot(111) - self.camera = Camera(self.fig) - - def on_train_end(self, trainer, pl_module, logs={}): - if self.make_gif: - gif_from_dir(directory=self.save_dir, - prefix=self.prefix, - duration=1.0 / self.fps) - if self.snap and self.make_mp4: - animation = self.camera.animate() - vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4") - if self.verbose: - print(f"Saving mp4 under {vid}.") - animation.save(vid, fps=self.fps, dpi=self.dpi) - - -class VisPointProtos(VisWeights): - """Visualization of prototypes. - .. TODO:: - Still in Progress. - """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.title = "Point Prototypes Visualization" - self.data_scatter_settings = { - "marker": "o", - "s": 30, - "edgecolor": "k", - "cmap": self.cmap, - } - self.protos_scatter_settings = { - "marker": "D", - "s": 50, - "edgecolor": "k", - "cmap": self.cmap, - } - - def on_epoch_start(self, trainer, pl_module, logs={}): - epoch = trainer.current_epoch - if self._skip_epoch(epoch): - return True - - self._clean_and_setup_ax() - - protos = pl_module.prototypes - labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy() - - if self.project_protos: - protos = self.model.projection(protos).numpy() - - color_map = color_scheme(n=len(set(labels)), - cmap=self.cmap, - zero_indexed=True) - # TODO Get rid of the assumption y values in [0, num_of_classes] - label_colors = [color_map[l] for l in labels] - - if self.data is not None: - x, y = self.data - # TODO Get rid of the assumption y values in [0, num_of_classes] - y_colors = [color_map[l] for l in y] - # x = self.model.projection(x) - if not isinstance(x, np.ndarray): - x = x.numpy() - - # Plot data points. - self.ax.scatter(x[:, 0], - x[:, 1], - c=y_colors, - **self.data_scatter_settings) - - # Paint decision regions. - if self.voronoi: - border = self.border - resolution = self.resolution - x = np.vstack((x, protos)) - x_min, x_max = x[:, 0].min(), x[:, 0].max() - y_min, y_max = x[:, 1].min(), x[:, 1].max() - x_min, x_max = x_min - border, x_max + border - y_min, y_max = y_min - border, y_max + border - try: - xx, yy = np.meshgrid( - np.arange(x_min, x_max, (x_max - x_min) / resolution), - np.arange(y_min, y_max, (x_max - x_min) / resolution), - ) - except ValueError as ve: - print(ve) - raise ValueError(f"x_min: {x_min}, x_max: {x_max}. " - f"x_min - x_max is {x_max - x_min}.") - except MemoryError as me: - print(me) - raise ValueError("Too many points. " - "Try reducing the resolution.") - mesh_input = np.c_[xx.ravel(), yy.ravel()] - - # Predict mesh labels. - if self.project_mesh: - mesh_input = self.model.projection(mesh_input) - - y_pred = pl_module.predict(torch.Tensor(mesh_input)) - y_pred = y_pred.reshape(xx.shape) - - # Plot voronoi regions. - self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) - - self.ax.set_xlim(left=x_min + 0, right=x_max - 0) - self.ax.set_ylim(bottom=y_min + 0, top=y_max - 0) - - # Plot prototypes. - self.ax.scatter(protos[:, 0], - protos[:, 1], - c=label_colors, - **self.protos_scatter_settings) - - # self._show_and_save(epoch) - - def on_epoch_end(self, trainer, pl_module, logs={}): - epoch = trainer.current_epoch - self._display_logs(self.ax, epoch, logs) - self._show_and_save(epoch) +from torch.utils.data import DataLoader, Dataset class Vis2DAbstract(pl.Callback): @@ -269,8 +18,9 @@ class Vis2DAbstract(pl.Callback): data, title="Prototype Visualization", cmap="viridis", - border=1, - resolution=50, + border=0.1, + resolution=100, + axis_off=False, show_protos=True, show=True, tensorboard=False, @@ -292,6 +42,7 @@ class Vis2DAbstract(pl.Callback): self.cmap = cmap self.border = border self.resolution = resolution + self.axis_off = axis_off self.show_protos = show_protos self.show = show self.tensorboard = tensorboard @@ -309,18 +60,21 @@ class Vis2DAbstract(pl.Callback): ax = self.fig.gca() ax.cla() ax.set_title(self.title) - ax.axis("off") if xlabel: ax.set_xlabel("Data dimension 1") if ylabel: ax.set_ylabel("Data dimension 2") + if self.axis_off: + ax.axis("off") return ax def get_mesh_input(self, x): - x_min, x_max = x[:, 0].min() - self.border, x[:, 0].max() + self.border - y_min, y_max = x[:, 1].min() - self.border, x[:, 1].max() + self.border - xx, yy = np.meshgrid(np.arange(x_min, x_max, 1 / self.resolution), - np.arange(y_min, y_max, 1 / self.resolution)) + x_shift = self.border * np.ptp(x[:, 0]) + y_shift = self.border * np.ptp(x[:, 1]) + x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift + y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift + xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution), + np.linspace(y_min, y_max, self.resolution)) mesh_input = np.c_[xx.ravel(), yy.ravel()] return mesh_input, xx, yy @@ -381,8 +135,8 @@ class VisGLVQ2D(Vis2DAbstract): x = np.vstack((x_train, protos)) mesh_input, xx, yy = self.get_mesh_input(x) _components = pl_module.proto_layer._components - y_pred = pl_module.predict( - torch.Tensor(mesh_input).type_as(_components)) + mesh_input = torch.Tensor(mesh_input).type_as(_components) + y_pred = pl_module.predict(mesh_input) y_pred = y_pred.cpu().reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) @@ -401,11 +155,14 @@ class VisSiameseGLVQ2D(Vis2DAbstract): protos = pl_module.prototypes plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train - x_train = pl_module.backbone( - torch.Tensor(x_train).to(pl_module.device)).cpu().detach() + device = pl_module.device + with torch.no_grad(): + x_train = pl_module.backbone(torch.Tensor(x_train).to(device)) + x_train = x_train.cpu().detach() if self.map_protos: - protos = pl_module.backbone( - torch.Tensor(protos).to(pl_module.device)).cpu().detach() + with torch.no_grad(): + protos = pl_module.backbone(torch.Tensor(protos).to(device)) + protos = protos.cpu().detach() ax = self.setup_ax() self.plot_data(ax, x_train, y_train) if self.show_protos: @@ -415,8 +172,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract): else: mesh_input, xx, yy = self.get_mesh_input(x_train) _components = pl_module.proto_layer._components - y_pred = pl_module.predict_latent( - torch.Tensor(mesh_input).type_as(_components)) + mesh_input = torch.Tensor(mesh_input).type_as(_components) + y_pred = pl_module.predict_latent(mesh_input, + map_protos=self.map_protos) y_pred = y_pred.cpu().reshape(xx.shape) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)