import warnings from pathlib import Path import matplotlib as mpl import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable from numpy.typing import NDArray from sklearn.decomposition import PCA def unique_targets(targets): """Generate dictionary of integer indices and the corresponding targets.""" return {target: i for i, target in enumerate(sorted(set(targets)))} def label2int(targets): """Convert string and/or integer labels to integer labels (necessary for torch).""" _unique_targets = unique_targets(targets) targets = [_unique_targets[target] for target in targets] return targets def omega2lambda(omega: NDArray) -> NDArray: """Normalize Ω (Omega) and return the CCM (Λ, Lambda).""" omega = omega / np.sqrt(np.trace(omega.T @ omega)) lmbd = omega @ omega.T return lmbd def vis_ccm( ccm: NDArray, labels: NDArray | list[str | int] | bool = True, title="Classification correlation matrix", cmap: str = "bwr", important_labels: int | None = None, no_diagonal: bool = False, show_cip: bool = True, show_crp: bool = False, sort: bool = False, cutout: int | None = None, fontsize: float = 6, filename: str | Path | None = None, ): """Visualise the Classification Correlation Matrix (CCM), Classification Influence Profile (CIP) and Classification Relevance Profile (CRP). Parameters: ----------- ccm : NDArray Classification Correlation Matrix (CCM) or Λ (Lambda) matrix. labels : NDArray or list of integers and/or strings or bool Supply a list of integers and/or strings. If `True`, generates an integer index. If `False`, doesn't show labels. title : str or None Title of the figure. cmap : str, default="bwr" For options of matplotlib colormaps, see: https://matplotlib.org/stable/tutorials/colors/colormaps.html important_labels : int, default=None Only show the `important_labels` most important labels. no_diagonal : bool, default=False Set the diagonal elements to zero. show_cip : bool, default=True Show the CIP. show_crp : bool, default=True Show the CRP. sort : bool, default=False Sort the features by their importance for classification (CIP). cutout : int or None, default=None If set, cuts out the `cutout` most important features and only shows those. fontsize : float, default=6 Overall fontsize. filename : str or Path If a filename is given, the figure is saved to that path. """ font = {"fontsize": fontsize} # profiles cip = np.abs(ccm).sum(axis=0) / len(ccm) crp = np.diagonal(ccm) if labels is True: ax_labels = list(range(len(ccm))) else: ax_labels = labels if sort or cutout: if cutout is not None: greatest_importance = np.delete(np.argsort(cip)[::-1][:cutout], 2) else: greatest_importance = np.delete(np.argsort(cip)[::-1], 2) new_ccm = np.zeros((len(greatest_importance), len(greatest_importance))) for i, x in enumerate(greatest_importance): for j, y in enumerate(greatest_importance): new_ccm[i, j] = ccm[x, y] ccm = new_ccm ax_labels = [ ax_labels[i].replace("m_", "").replace("p_", "") for i in greatest_importance.tolist() ] # ax_labels = [ax_labels[i] for i in greatest_importance.tolist()] # profiles # cip = np.abs(ccm).sum(axis=0) # crp = np.diagonal(ccm) cip = cip[greatest_importance] crp = crp[greatest_importance] if no_diagonal: np.fill_diagonal(ccm, 0) if important_labels: greatest_importance = np.argsort(cip)[::-1][:important_labels] ax_labels = [ l if i in greatest_importance else "" for i, l in enumerate(ax_labels) ] if not labels: orig_labels = ax_labels ax_labels = [""] * len(ax_labels) fig, ax = plt.subplots() # matrix plt.title(title) maxi = max(-np.amin(ccm), np.amax(ccm)) * 1.1 vmin = -maxi vmax = maxi cax = ax.imshow(ccm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax) # plt.xticks(np.arange(0, len(ax_labels))) plt.yticks(np.arange(0, len(ax_labels))) # ax.set_xticklabels([""] * len(ax_labels)) ax.set_yticklabels(ax_labels, fontdict=font) if show_cip or show_crp: plt.xticks([]) else: plt.xticks(np.arange(0, len(ax_labels))) ax.set_xticklabels(ax_labels, fontdict=font) divider = make_axes_locatable(plt.gca()) # colorbar ax_color = divider.append_axes("right", "5%", pad="10%") plt.colorbar(cax, cax=ax_color) if show_cip: ax_inf = divider.append_axes("bottom", "30%", pad="5%") ax_inf.plot(cip, marker=".", lw=0.6, ms=3, color="grey", label="CIP") if show_crp: ax_inf.plot(crp, marker=".", lw=0.6, ms=3, color="darkgrey", label="CRP") ax_inf.fill_between(range(len(ccm)), cip, crp, alpha=0.3, color="lightgrey") plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right") if show_cip or show_crp: ax_inf.set_xticks(np.arange(0, len(ax_labels))) ax_inf.set_xticklabels(ax_labels, fontdict=font) plt.xticks(rotation=45, ha="right") ax_inf.set_ylim(bottom=0) # plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right", fontsize=8) plt.tight_layout() if filename: plt.savefig(filename + ".png", dpi=600, bbox_inches="tight") plt.show() plt.close() def vis_data( data: NDArray, targets: NDArray | list[int | str] | None = None, prototypes: NDArray | None = None, prototype_labels: NDArray | list[int | str] | None = None, omega: NDArray | None = None, dimensions: int | None = None, title: str | None = None, size: float = 20, alpha: float = 0.5, cmap: str = "tab10", proto_marker: str = "o", proto_color: str = "r", filename: str | Path | None = None, ): """Visualise multidimensional data in one to three dimensions. Dimensionality reduction is achieved through either Principal Component Analysis (PCA) or, if the Ω (Omega) matrix is supplied, by projection into the latent space. Parameters ---------- data : NDArray Data array. targets : NDArray or list of integers and/or strings or None, default=None Supply a list of targets (class labels) as integers and/or strings. If `None`, it doesn't show labels. prototypes : NDArray or None, default=None If prototype array is supplied, they are displayed and highlighted (see `proto_marker` and `proto_color`). prototype_labels : NDArray or list of integers and/or strings or None, default=None Supply a list of targets (class labels) as integers and/or strings. If `None`, it doesn't show labels. omega : NDArray or None, default=None Ω (Omega) matrix for projection into latent space (see General Matrix Learning Vector Quantization (GMLVQ)). If not supplied, PCA is chosen for dimensionality reduction. dimensions : int or None, default=None Desired number of projection dimensions. If not supplied, chooses either according to shape of Ω (Omega) or `2` for PCA. title : str Title of the figure. size : float, default=20 Size of the points. alpha : float, default=0.5 Alpha transparency for data points (e.g. for better visibility of prototypes). cmap : str, default="tab10" Where to choose colours for classes from. For options of matplotlib colormaps, see: https://matplotlib.org/stable/tutorials/colors/colormaps.html proto_marker : str, default="o" For more options, see: https://matplotlib.org/stable/api/markers_api.html#module-matplotlib.markers proto_color : str, default="r" Highlighting color for prototypes. For more options, see: https://matplotlib.org/stable/gallery/color/named_colors.html#base-colors filename : str or Path If a filename is given, the figure is saved to that path. """ fig = plt.figure(1, figsize=(8, 6)) n_data, n_dimensions = data.shape target_reverse = {i: target for i, target in enumerate(sorted(set(targets)))} method = "" if omega is not None: assert 1 <= omega.shape[1] <= 3, ( "Omega needs to be within the 3 dimensions " "that humans are accustomed to." ) if dimensions is None: dimensions = omega.shape[1] data = data @ omega prototypes = prototypes @ omega method = "omega" if omega.shape[1] == 1: data = np.concatenate((data, np.zeros((data.shape[0], 1))), axis=1) prototypes = np.concatenate( (prototypes, np.zeros((prototypes.shape[0], 1))), axis=1 ) if data.shape[1] >= 3: if dimensions is None: dimensions = 2 pca = PCA(n_components=dimensions) data = pca.fit_transform(data) prototypes = pca.transform(prototypes) method = "PCA" # display data if data.shape[1] <= 2: ax = fig.add_subplot(111) scatter = ax.scatter( *data.T, c=label2int(targets), marker="o", s=size, edgecolors="none", cmap=cmap, alpha=alpha, ) elif data.shape[1] == 3: ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110) scatter = ax.scatter( *data.T, c=label2int(targets), marker="o", s=size, edgecolors="none", cmap=cmap, alpha=alpha, ) # display prototypes(?) if prototypes is not None and prototype_labels is None: warnings.warn( "You forgot to supply labels to the prototypes to " "color them correctly." ) if prototypes is not None and prototype_labels is not None: ax.scatter( *prototypes.T, c=label2int(prototype_labels), cmap=cmap, marker=proto_marker, s=size * 2, edgecolors=proto_color, linewidths=4, ) # axis labels if method == "PCA": if pca.n_components >= 1: ax.set_xlabel( "1st eigenvector " f"({pca.explained_variance_ratio_[0]:.2%})" ) if pca.n_components >= 2: ax.set_ylabel( "2nd eigenvector " f"({pca.explained_variance_ratio_[1]:.2%})" ) if pca.n_components == 3: ax.set_zlabel( f"3rd eigenvector " f"({pca.explained_variance_ratio_[2]:.2%})" ) ax.set_title("PCA directions") elif method == "omega": if data.shape[1] >= 1: ax.set_xlabel("projection dimension 1") if data.shape[1] >= 2: ax.set_ylabel("projection dimension 2") if data.shape[1] == 3: ax.set_zlabel("projection dimension 3") ax.set_title("Omega transform") cmap = mpl.colormaps[cmap].resampled(len(set(targets))) ax.legend( handles=[ mpatches.Patch(color=cmap(t, alpha=alpha), label=target_reverse[t]) for t in target_reverse ] ) if filename: plt.savefig(filename + ".png", dpi=600, bbox_inches="tight") plt.show()