From 76668e552cba9974d6fe7a61e6ee548fdc291baa Mon Sep 17 00:00:00 2001 From: juvilius Date: Tue, 29 Aug 2023 11:25:53 +0200 Subject: [PATCH] Initial commit. Includes: visualiser.py with documentation for visualising the CCM and data. --- visualiser.py | 343 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 343 insertions(+) create mode 100644 visualiser.py diff --git a/visualiser.py b/visualiser.py new file mode 100644 index 0000000..73964d6 --- /dev/null +++ b/visualiser.py @@ -0,0 +1,343 @@ +import warnings +from pathlib import Path + +import matplotlib as mpl +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable +from numpy.typing import NDArray +from sklearn.decomposition import PCA + + +def unique_targets(targets): + """Generate dictionary of integer indices and the corresponding targets.""" + + return {target: i for i, target in enumerate(sorted(set(targets)))} + + +def label2int(targets): + """Convert string and/or integer labels to integer labels (necessary for + torch).""" + + _unique_targets = unique_targets(targets) + targets = [_unique_targets[target] for target in targets] + return targets + + +def omega2lambda(omega: NDArray) -> NDArray: + """Normalize Ω (Omega) and return the CCM (Λ, Lambda).""" + + omega = omega / np.sqrt(np.trace(omega.T @ omega)) + lmbd = omega @ omega.T + return lmbd + + +def vis_ccm( + ccm: NDArray, + labels: NDArray | list[str | int] | bool = True, + title="Classification correlation matrix", + cmap: str = "bwr", + important_labels: int | None = None, + no_diagonal: bool = False, + show_cip: bool = True, + show_crp: bool = False, + sort: bool = False, + cutout: int | None = None, + fontsize: float = 6, + filename: str | Path | None = None, +): + """Visualise the Classification Correlation Matrix (CCM), Classification + Influence Profile (CIP) and Classification Relevance Profile (CRP). + + Parameters: + ----------- + ccm : NDArray + Classification Correlation Matrix (CCM) or Λ (Lambda) matrix. + labels : NDArray or list of integers and/or strings or bool + Supply a list of integers and/or strings. If `True`, generates an + integer index. If `False`, doesn't show labels. + title : str or None + Title of the figure. + cmap : str, default="bwr" + For options of matplotlib colormaps, see: + https://matplotlib.org/stable/tutorials/colors/colormaps.html + important_labels : int, default=None + Only show the `important_labels` most important labels. + no_diagonal : bool, default=False + Set the diagonal elements to zero. + show_cip : bool, default=True + Show the CIP. + show_crp : bool, default=True + Show the CRP. + sort : bool, default=False + Sort the features by their importance for classification (CIP). + cutout : int or None, default=None + If set, cuts out the `cutout` most important features and only shows + those. + fontsize : float, default=6 + Overall fontsize. + filename : str or Path + If a filename is given, the figure is saved to that path. + """ + + font = {"fontsize": fontsize} + # profiles + cip = np.abs(ccm).sum(axis=0) / len(ccm) + crp = np.diagonal(ccm) + if labels is True: + ax_labels = list(range(len(ccm))) + else: + ax_labels = labels + + if sort or cutout: + if cutout is not None: + greatest_importance = np.delete(np.argsort(cip)[::-1][:cutout], 2) + else: + greatest_importance = np.delete(np.argsort(cip)[::-1], 2) + + new_ccm = np.zeros((len(greatest_importance), len(greatest_importance))) + for i, x in enumerate(greatest_importance): + for j, y in enumerate(greatest_importance): + new_ccm[i, j] = ccm[x, y] + ccm = new_ccm + ax_labels = [ + ax_labels[i].replace("m_", "").replace("p_", "") + for i in greatest_importance.tolist() + ] + # ax_labels = [ax_labels[i] for i in greatest_importance.tolist()] + + # profiles + # cip = np.abs(ccm).sum(axis=0) + # crp = np.diagonal(ccm) + cip = cip[greatest_importance] + crp = crp[greatest_importance] + + if no_diagonal: + np.fill_diagonal(ccm, 0) + + if important_labels: + greatest_importance = np.argsort(cip)[::-1][:important_labels] + ax_labels = [ + l if i in greatest_importance else "" for i, l in enumerate(ax_labels) + ] + if not labels: + orig_labels = ax_labels + ax_labels = [""] * len(ax_labels) + + fig, ax = plt.subplots() + + # matrix + plt.title(title) + + maxi = max(-np.amin(ccm), np.amax(ccm)) * 1.1 + vmin = -maxi + vmax = maxi + + cax = ax.imshow(ccm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax) + + # plt.xticks(np.arange(0, len(ax_labels))) + plt.yticks(np.arange(0, len(ax_labels))) + # ax.set_xticklabels([""] * len(ax_labels)) + ax.set_yticklabels(ax_labels, fontdict=font) + if show_cip or show_crp: + plt.xticks([]) + else: + plt.xticks(np.arange(0, len(ax_labels))) + ax.set_xticklabels(ax_labels, fontdict=font) + + divider = make_axes_locatable(plt.gca()) + + # colorbar + ax_color = divider.append_axes("right", "5%", pad="10%") + plt.colorbar(cax, cax=ax_color) + + if show_cip: + ax_inf = divider.append_axes("bottom", "30%", pad="5%") + ax_inf.plot(cip, marker=".", lw=0.6, ms=3, color="grey", label="CIP") + if show_crp: + ax_inf.plot(crp, marker=".", lw=0.6, ms=3, color="darkgrey", label="CRP") + ax_inf.fill_between(range(len(ccm)), cip, crp, alpha=0.3, color="lightgrey") + plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right") + + if show_cip or show_crp: + ax_inf.set_xticks(np.arange(0, len(ax_labels))) + ax_inf.set_xticklabels(ax_labels, fontdict=font) + plt.xticks(rotation=45, ha="right") + ax_inf.set_ylim(bottom=0) + + # plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right", fontsize=8) + plt.tight_layout() + if filename: + plt.savefig(filename + ".png", dpi=600, bbox_inches="tight") + plt.show() + plt.close() + + +def vis_data( + data: NDArray, + targets: NDArray | list[int | str] | None = None, + prototypes: NDArray | None = None, + prototype_labels: NDArray | list[int | str] | None = None, + omega: NDArray | None = None, + dimensions: int | None = None, + title: str | None = None, + size: float = 20, + alpha: float = 0.5, + cmap: str = "tab10", + proto_marker: str = "o", + proto_color: str = "r", + filename: str | Path | None = None, +): + """Visualise multidimensional data in one to three dimensions. + Dimensionality reduction is achieved through either Principal Component + Analysis (PCA) or, if the Ω (Omega) matrix is supplied, by projection into the + latent space. + + Parameters + ---------- + data : NDArray + Data array. + targets : NDArray or list of integers and/or strings or None, default=None + Supply a list of targets (class labels) as integers and/or strings. If + `None`, it doesn't show labels. + prototypes : NDArray or None, default=None + If prototype array is supplied, they are displayed and highlighted (see + `proto_marker` and `proto_color`). + prototype_labels : NDArray or list of integers and/or strings or None, + default=None + Supply a list of targets (class labels) as integers and/or strings. If + `None`, it doesn't show labels. + omega : NDArray or None, default=None + Ω (Omega) matrix for projection into latent space (see General Matrix + Learning Vector Quantization (GMLVQ)). If not supplied, PCA is chosen + for dimensionality reduction. + dimensions : int or None, default=None + Desired number of projection dimensions. If not supplied, chooses + either according to shape of Ω (Omega) or `2` for PCA. + title : str + Title of the figure. + size : float, default=20 + Size of the points. + alpha : float, default=0.5 + Alpha transparency for data points (e.g. for better visibility of + prototypes). + cmap : str, default="tab10" + Where to choose colours for classes from. For options of matplotlib + colormaps, see: + https://matplotlib.org/stable/tutorials/colors/colormaps.html + proto_marker : str, default="o" + For more options, see: + https://matplotlib.org/stable/api/markers_api.html#module-matplotlib.markers + proto_color : str, default="r" + Highlighting color for prototypes. For more options, see: + https://matplotlib.org/stable/gallery/color/named_colors.html#base-colors + filename : str or Path + If a filename is given, the figure is saved to that path. + """ + + fig = plt.figure(1, figsize=(8, 6)) + n_data, n_dimensions = data.shape + target_reverse = {i: target for i, target in enumerate(sorted(set(targets)))} + method = "" + + if omega is not None: + assert 1 <= omega.shape[1] <= 3, ( + "Omega needs to be within the 3 dimensions " + "that humans are accustomed to." + ) + if dimensions is None: + dimensions = omega.shape[1] + data = data @ omega + prototypes = prototypes @ omega + method = "omega" + if omega.shape[1] == 1: + data = np.concatenate((data, np.zeros((data.shape[0], 1))), axis=1) + prototypes = np.concatenate( + (prototypes, np.zeros((prototypes.shape[0], 1))), axis=1 + ) + + if data.shape[1] >= 3: + if dimensions is None: + dimensions = 2 + pca = PCA(n_components=dimensions) + data = pca.fit_transform(data) + prototypes = pca.transform(prototypes) + method = "PCA" + + # display data + if data.shape[1] <= 2: + ax = fig.add_subplot(111) + scatter = ax.scatter( + *data.T, + c=label2int(targets), + marker="o", + s=size, + edgecolors="none", + cmap=cmap, + alpha=alpha, + ) + + elif data.shape[1] == 3: + ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110) + scatter = ax.scatter( + *data.T, + c=label2int(targets), + marker="o", + s=size, + edgecolors="none", + cmap=cmap, + alpha=alpha, + ) + + # display prototypes(?) + if prototypes is not None and prototype_labels is None: + warnings.warn( + "You forgot to supply labels to the prototypes to " "color them correctly." + ) + + if prototypes is not None and prototype_labels is not None: + ax.scatter( + *prototypes.T, + c=label2int(prototype_labels), + cmap=cmap, + marker=proto_marker, + s=size * 2, + edgecolors=proto_color, + linewidths=4, + ) + + # axis labels + if method == "PCA": + if pca.n_components >= 1: + ax.set_xlabel( + "1st eigenvector " f"({pca.explained_variance_ratio_[0]:.2%})" + ) + if pca.n_components >= 2: + ax.set_ylabel( + "2nd eigenvector " f"({pca.explained_variance_ratio_[1]:.2%})" + ) + if pca.n_components == 3: + ax.set_zlabel( + f"3rd eigenvector " f"({pca.explained_variance_ratio_[2]:.2%})" + ) + ax.set_title("PCA directions") + elif method == "omega": + if data.shape[1] >= 1: + ax.set_xlabel("projection dimension 1") + if data.shape[1] >= 2: + ax.set_ylabel("projection dimension 2") + if data.shape[1] == 3: + ax.set_zlabel("projection dimension 3") + ax.set_title("Omega transform") + + cmap = mpl.colormaps[cmap].resampled(len(set(targets))) + ax.legend( + handles=[ + mpatches.Patch(color=cmap(t, alpha=alpha), label=target_reverse[t]) + for t in target_reverse + ] + ) + if filename: + plt.savefig(filename + ".png", dpi=600, bbox_inches="tight") + plt.show()