diff --git a/prototorch/models/callbacks/visualization.py b/prototorch/models/callbacks/visualization.py new file mode 100644 index 0000000..eb60f67 --- /dev/null +++ b/prototorch/models/callbacks/visualization.py @@ -0,0 +1,261 @@ +import os + +import numpy as np +import torch +from matplotlib import pyplot as plt +from matplotlib.offsetbox import AnchoredText + +from prototorch.utils.celluloid import Camera +from prototorch.utils.colors import color_scheme +from prototorch.utils.utils import (gif_from_dir, make_directory, + prettify_string) + + +class VisWeights(Callback): + """Abstract weight visualization callback.""" + def __init__( + self, + data=None, + ignore_last_output_row=False, + label_map=None, + project_mesh=False, + project_protos=False, + voronoi=False, + axis_off=True, + cmap="viridis", + show=True, + display_logs=True, + display_logs_settings={}, + pause_time=0.5, + border=1, + resolution=10, + interval=False, + save=False, + snap=True, + save_dir="./img", + make_gif=False, + make_mp4=False, + verbose=True, + dpi=500, + fps=5, + figsize=(11, 8.5), # standard paper in inches + prefix="", + distance_layer_index=-1, + **kwargs, + ): + super().__init__(**kwargs) + self.data = data + self.ignore_last_output_row = ignore_last_output_row + self.label_map = label_map + self.voronoi = voronoi + self.axis_off = True + self.project_mesh = project_mesh + self.project_protos = project_protos + self.cmap = cmap + self.show = show + self.display_logs = display_logs + self.display_logs_settings = display_logs_settings + self.pause_time = pause_time + self.border = border + self.resolution = resolution + self.interval = interval + self.save = save + self.snap = snap + self.save_dir = save_dir + self.make_gif = make_gif + self.make_mp4 = make_mp4 + self.verbose = verbose + self.dpi = dpi + self.fps = fps + self.figsize = figsize + self.prefix = prefix + self.distance_layer_index = distance_layer_index + self.title = "Weights Visualization" + make_directory(self.save_dir) + + def _skip_epoch(self, epoch): + if self.interval: + if epoch % self.interval != 0: + return True + return False + + def _clean_and_setup_ax(self): + ax = self.ax + if not self.snap: + ax.cla() + ax.set_title(self.title) + if self.axis_off: + ax.axis("off") + + def _savefig(self, fignum, orientation="horizontal"): + figname = f"{self.save_dir}/{self.prefix}{fignum:05d}.png" + figsize = self.figsize + if orientation == "vertical": + figsize = figsize[::-1] + elif orientation == "horizontal": + pass + else: + pass + self.fig.set_size_inches(figsize, forward=False) + self.fig.savefig(figname, dpi=self.dpi) + + def _show_and_save(self, epoch): + if self.show: + plt.pause(self.pause_time) + if self.save: + self._savefig(epoch) + if self.snap: + self.camera.snap() + + def _display_logs(self, ax, epoch, logs): + if self.display_logs: + settings = dict( + loc="lower right", + # padding between the text and bounding box + pad=0.5, + # padding between the bounding box and the axes + borderpad=1.0, + # https://matplotlib.org/api/text_api.html#matplotlib.text.Text + prop=dict( + fontfamily="monospace", + fontweight="medium", + fontsize=12, + ), + ) + + # Override settings with self.display_logs_settings. + settings = {**settings, **self.display_logs_settings} + + log_string = f"""Epoch: {epoch:04d}, + val_loss: {logs.get('val_loss', np.nan):.03f}, + val_acc: {logs.get('val_acc', np.nan):.03f}, + loss: {logs.get('loss', np.nan):.03f}, + acc: {logs.get('acc', np.nan):.03f} + """ + log_string = prettify_string(log_string, end="") + # https://matplotlib.org/api/offsetbox_api.html#matplotlib.offsetbox.AnchoredText + anchored_text = AnchoredText(log_string, **settings) + self.ax.add_artist(anchored_text) + + def on_train_start(self, trainer, pl_module, logs={}): + self.fig = plt.figure(self.title) + self.fig.set_size_inches(self.figsize, forward=False) + self.ax = self.fig.add_subplot(111) + self.camera = Camera(self.fig) + + def on_train_end(self, trainer, pl_module, logs={}): + if self.make_gif: + gif_from_dir(directory=self.save_dir, + prefix=self.prefix, + duration=1.0 / self.fps) + if self.snap and self.make_mp4: + animation = self.camera.animate() + vid = os.path.join(self.save_dir, f"{self.prefix}animation.mp4") + if self.verbose: + print(f"Saving mp4 under {vid}.") + animation.save(vid, fps=self.fps, dpi=self.dpi) + + +class VisPointProtos(VisWeights): + """Visualization of prototypes. + .. TODO:: + Still in Progress. + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.title = "Point Prototypes Visualization" + self.data_scatter_settings = { + "marker": "o", + "s": 30, + "edgecolor": "k", + "cmap": self.cmap, + } + self.protos_scatter_settings = { + "marker": "D", + "s": 50, + "edgecolor": "k", + "cmap": self.cmap, + } + + def on_epoch_start(self, trainer, pl_module, logs={}): + epoch = trainer.current_epoch + if self._skip_epoch(epoch): + return True + + self._clean_and_setup_ax() + + protos = pl_module.prototypes + labels = pl_module.proto_layer.prototype_labels.detach().cpu().numpy() + + if self.project_protos: + protos = self.model.projection(protos).numpy() + + color_map = color_scheme(n=len(set(labels)), + cmap=self.cmap, + zero_indexed=True) + # TODO Get rid of the assumption y values in [0, num_of_classes] + label_colors = [color_map[l] for l in labels] + + if self.data is not None: + x, y = self.data + # TODO Get rid of the assumption y values in [0, num_of_classes] + y_colors = [color_map[l] for l in y] + # x = self.model.projection(x) + if not isinstance(x, np.ndarray): + x = x.numpy() + + # Plot data points. + self.ax.scatter(x[:, 0], + x[:, 1], + c=y_colors, + **self.data_scatter_settings) + + # Paint decision regions. + if self.voronoi: + border = self.border + resolution = self.resolution + x = np.vstack((x, protos)) + x_min, x_max = x[:, 0].min(), x[:, 0].max() + y_min, y_max = x[:, 1].min(), x[:, 1].max() + x_min, x_max = x_min - border, x_max + border + y_min, y_max = y_min - border, y_max + border + try: + xx, yy = np.meshgrid( + np.arange(x_min, x_max, (x_max - x_min) / resolution), + np.arange(y_min, y_max, (x_max - x_min) / resolution), + ) + except ValueError as ve: + print(ve) + raise ValueError(f"x_min: {x_min}, x_max: {x_max}. " + f"x_min - x_max is {x_max - x_min}.") + except MemoryError as me: + print(me) + raise ValueError("Too many points. " + "Try reducing the resolution.") + mesh_input = np.c_[xx.ravel(), yy.ravel()] + + # Predict mesh labels. + if self.project_mesh: + mesh_input = self.model.projection(mesh_input) + + y_pred = pl_module.predict(torch.Tensor(mesh_input)) + y_pred = y_pred.reshape(xx.shape) + + # Plot voronoi regions. + self.ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) + + self.ax.set_xlim(left=x_min + 0, right=x_max - 0) + self.ax.set_ylim(bottom=y_min + 0, top=y_max - 0) + + # Plot prototypes. + self.ax.scatter(protos[:, 0], + protos[:, 1], + c=label_colors, + **self.protos_scatter_settings) + + # self._show_and_save(epoch) + + def on_epoch_end(self, trainer, pl_module, logs={}): + epoch = trainer.current_epoch + self._display_logs(self.ax, epoch, logs) + self._show_and_save(epoch)