CIVis/visualiser.py

344 lines
12 KiB
Python
Raw Normal View History

import warnings
from pathlib import Path
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from numpy.typing import NDArray
from sklearn.decomposition import PCA
def unique_targets(targets):
    """Generate dictionary of the distinct targets and their integer indices.

    Parameters
    ----------
    targets : iterable of hashable
        Class labels (integers and/or strings). Duplicates are collapsed.

    Returns
    -------
    dict
        Maps every distinct target to its position in the sorted sequence of
        distinct targets.
    """
    distinct = set(targets)
    try:
        ordered = sorted(distinct)
    except TypeError:
        # Integers and strings cannot be compared with each other, so a mixed
        # label set is ordered by type name first and by value within a type.
        ordered = sorted(distinct, key=lambda t: (type(t).__name__, t))
    return {target: i for i, target in enumerate(ordered)}
def label2int(targets):
    """Translate string and/or integer labels into integer labels (necessary
    for torch).

    Every label is replaced by the index that `unique_targets` assigns to it;
    the order of `targets` is preserved.
    """
    index_of = unique_targets(targets)
    return list(map(index_of.__getitem__, targets))
def omega2lambda(omega: NDArray) -> NDArray:
    """Return the CCM (Λ, Lambda) built from the normalised Ω (Omega).

    Ω is divided by the square root of trace(ΩᵀΩ); Λ is the product of the
    scaled matrix with its own transpose.
    """
    scale = np.sqrt(np.trace(np.matmul(omega.T, omega)))
    scaled = omega / scale
    return np.matmul(scaled, scaled.T)
def vis_ccm(
    ccm: NDArray,
    labels: NDArray | list[str | int] | bool = True,
    title="Classification correlation matrix",
    cmap: str = "bwr",
    important_labels: int | None = None,
    no_diagonal: bool = False,
    show_cip: bool = True,
    show_crp: bool = False,
    sort: bool = False,
    cutout: int | None = None,
    fontsize: float = 6,
    filename: str | Path | None = None,
):
    """Visualise the Classification Correlation Matrix (CCM), Classification
    Influence Profile (CIP) and Classification Relevance Profile (CRP).

    The CIP is the mean absolute value of each CCM column, the CRP is the
    diagonal of the CCM. Both are computed from the matrix as passed in, i.e.
    before `no_diagonal` is applied. The caller's matrix is never modified.

    Parameters
    ----------
    ccm : NDArray
        Classification Correlation Matrix (CCM) or Λ (Lambda) matrix.
    labels : NDArray or list of integers and/or strings or bool
        Supply a list of integers and/or strings. If `True`, generates an
        integer index. If `False`, doesn't show labels.
    title : str or None
        Title of the figure.
    cmap : str, default="bwr"
        For options of matplotlib colormaps, see:
        https://matplotlib.org/stable/tutorials/colors/colormaps.html
    important_labels : int, default=None
        Only show the `important_labels` most important labels.
    no_diagonal : bool, default=False
        Set the diagonal elements to zero.
    show_cip : bool, default=True
        Show the CIP.
    show_crp : bool, default=False
        Show the CRP.
    sort : bool, default=False
        Sort the features by their importance for classification (CIP).
    cutout : int or None, default=None
        If set, cuts out the `cutout` most important features and only shows
        those (sorted by importance).
    fontsize : float, default=6
        Overall fontsize.
    filename : str or Path
        If a filename is given, the figure is saved to that path with ".png"
        appended.
    """
    font = {"fontsize": fontsize}
    # Work on a copy so that `no_diagonal` never modifies the caller's matrix.
    ccm = np.array(ccm, copy=True)
    n_features = len(ccm)

    # profiles
    cip = np.abs(ccm).sum(axis=0) / n_features
    # `np.diagonal` returns a view; copy it so zeroing the diagonal later
    # does not wipe the CRP as well.
    crp = np.diagonal(ccm).copy()

    hide_labels = labels is False
    if labels is True or hide_labels:
        ax_labels = list(range(n_features))
    else:
        ax_labels = list(labels)

    if sort or cutout:
        order = np.argsort(cip)[::-1]
        if cutout is not None:
            order = order[:cutout]
        ccm = ccm[np.ix_(order, order)]
        # NOTE(review): stripping "m_" / "p_" from string labels is kept from
        # the previous implementation and looks dataset-specific - confirm.
        # Non-string labels (e.g. the generated integer index) are left as is.
        ax_labels = [
            lbl.replace("m_", "").replace("p_", "") if isinstance(lbl, str) else lbl
            for lbl in (ax_labels[i] for i in order.tolist())
        ]
        cip = cip[order]
        crp = crp[order]

    if no_diagonal:
        np.fill_diagonal(ccm, 0)

    if important_labels:
        keep = set(np.argsort(cip)[::-1][:important_labels].tolist())
        ax_labels = [lbl if i in keep else "" for i, lbl in enumerate(ax_labels)]

    if hide_labels:
        ax_labels = [""] * len(ax_labels)

    fig, ax = plt.subplots()

    # matrix
    if title is not None:
        ax.set_title(title)
    # Symmetric colour limits keep zero at the centre of the colormap.
    maxi = max(-np.amin(ccm), np.amax(ccm)) * 1.1
    image = ax.imshow(ccm, interpolation="nearest", cmap=cmap, vmin=-maxi, vmax=maxi)

    ticks = np.arange(len(ax_labels))
    ax.set_yticks(ticks)
    ax.set_yticklabels(ax_labels, fontdict=font)

    show_profiles = show_cip or show_crp
    if show_profiles:
        # The x labels are drawn below the profile axis instead.
        ax.set_xticks([])
    else:
        ax.set_xticks(ticks)
        ax.set_xticklabels(ax_labels, fontdict=font)

    divider = make_axes_locatable(ax)

    # colorbar
    ax_color = divider.append_axes("right", "5%", pad="10%")
    fig.colorbar(image, cax=ax_color)

    # profiles below the matrix
    if show_profiles:
        ax_inf = divider.append_axes("bottom", "30%", pad="5%")
        if show_cip:
            ax_inf.plot(cip, marker=".", lw=0.6, ms=3, color="grey", label="CIP")
        if show_crp:
            ax_inf.plot(crp, marker=".", lw=0.6, ms=3, color="darkgrey", label="CRP")
            if show_cip:
                ax_inf.fill_between(
                    range(len(ccm)), cip, crp, alpha=0.3, color="lightgrey"
                )
            ax_inf.legend(bbox_to_anchor=(1.43, 0.7), loc="center right")
        ax_inf.set_xticks(ticks)
        ax_inf.set_xticklabels(ax_labels, fontdict=font, rotation=45, ha="right")
        ax_inf.set_ylim(bottom=0)

    plt.tight_layout()
    if filename:
        # An f-string accepts both `str` and `Path`; `Path + str` would raise.
        plt.savefig(f"{filename}.png", dpi=600, bbox_inches="tight")
    plt.show()
    plt.close(fig)
def _append_zero_column(points):
    """Return `points` with one all-zero column appended on the right."""
    return np.concatenate((points, np.zeros((points.shape[0], 1))), axis=1)


def vis_data(
    data: NDArray,
    targets: NDArray | list[int | str] | None = None,
    prototypes: NDArray | None = None,
    prototype_labels: NDArray | list[int | str] | None = None,
    omega: NDArray | None = None,
    dimensions: int | None = None,
    title: str | None = None,
    size: float = 20,
    alpha: float = 0.5,
    cmap: str = "tab10",
    proto_marker: str = "o",
    proto_color: str = "r",
    filename: str | Path | None = None,
):
    """Visualise multidimensional data in one to three dimensions.

    Dimensionality reduction is achieved through either Principal Component
    Analysis (PCA) or, if the Ω (Omega) matrix is supplied, by projection into
    the latent space. One-dimensional results are drawn on a 2D axis with a
    constant second coordinate of zero.

    Parameters
    ----------
    data : NDArray
        Data array of shape (samples, features).
    targets : NDArray or list of integers and/or strings or None, default=None
        Supply a list of targets (class labels) as integers and/or strings. If
        `None`, the points are not coloured by class and no legend is drawn.
    prototypes : NDArray or None, default=None
        If prototype array is supplied together with `prototype_labels`, they
        are displayed and highlighted (see `proto_marker` and `proto_color`).
    prototype_labels : NDArray or list of integers and/or strings or None,
        default=None
        Class labels of the prototypes. If prototypes are given without
        labels, a warning is issued and the prototypes are not drawn.
    omega : NDArray or None, default=None
        Ω (Omega) matrix for projection into latent space (see General Matrix
        Learning Vector Quantization (GMLVQ)). If not supplied, PCA is chosen
        for dimensionality reduction.
    dimensions : int or None, default=None
        Desired number of projection dimensions. If not supplied, chooses
        either according to shape of Ω (Omega) or `2` for PCA.
    title : str or None, default=None
        Title of the figure. If `None`, a title naming the projection method
        is used.
    size : float, default=20
        Size of the points.
    alpha : float, default=0.5
        Alpha transparency for data points (e.g. for better visibility of
        prototypes).
    cmap : str, default="tab10"
        Where to choose colours for classes from. For options of matplotlib
        colormaps, see:
        https://matplotlib.org/stable/tutorials/colors/colormaps.html
    proto_marker : str, default="o"
        For more options, see:
        https://matplotlib.org/stable/api/markers_api.html#module-matplotlib.markers
    proto_color : str, default="r"
        Highlighting color for prototypes. For more options, see:
        https://matplotlib.org/stable/gallery/color/named_colors.html#base-colors
    filename : str or Path
        If a filename is given, the figure is saved to that path with ".png"
        appended.

    Raises
    ------
    ValueError
        If `data` is not two-dimensional or the reduced data does not have
        between one and three columns.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(
            f"`data` must be two-dimensional, got {data.ndim} dimension(s)."
        )

    # Class bookkeeping; going through `unique_targets` keeps the legend order
    # identical to the integer codes produced by `label2int`.
    if targets is not None:
        classes = list(unique_targets(targets))
        colours = label2int(targets)
    else:
        classes = []
        colours = None

    method = ""
    if omega is not None:
        assert 1 <= omega.shape[1] <= 3, (
            "Omega needs to be within the 3 dimensions "
            "that humans are accustomed to."
        )
        if dimensions is None:
            dimensions = omega.shape[1]
        data = data @ omega
        if prototypes is not None:
            prototypes = prototypes @ omega
        method = "omega"

    # NOTE(review): as before, PCA also runs on a three-column Ω projection
    # and then takes over the axis labels - confirm this is intended.
    if data.shape[1] >= 3:
        if dimensions is None:
            dimensions = 2
        pca = PCA(n_components=dimensions)
        data = pca.fit_transform(data)
        if prototypes is not None:
            prototypes = pca.transform(prototypes)
        method = "PCA"

    if not 1 <= data.shape[1] <= 3:
        raise ValueError(
            f"Cannot display {data.shape[1]} dimensions; "
            "choose `dimensions` between 1 and 3."
        )

    # A scatter plot needs at least two coordinates per point.
    if data.shape[1] == 1:
        data = _append_zero_column(data)
        if prototypes is not None:
            prototypes = _append_zero_column(prototypes)

    fig = plt.figure(1, figsize=(8, 6))

    # display data
    if data.shape[1] == 3:
        ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110)
    else:
        ax = fig.add_subplot(111)
    scatter_kwargs = {
        "marker": "o",
        "s": size,
        "edgecolors": "none",
        "alpha": alpha,
    }
    if colours is not None:
        # Only pass a colormap when there are class codes to map.
        scatter_kwargs.update(c=colours, cmap=cmap)
    ax.scatter(*data.T, **scatter_kwargs)

    # display prototypes
    if prototypes is not None:
        if prototype_labels is None:
            warnings.warn(
                "You forgot to supply labels to the prototypes to "
                "color them correctly."
            )
        else:
            ax.scatter(
                *prototypes.T,
                c=label2int(prototype_labels),
                cmap=cmap,
                marker=proto_marker,
                s=size * 2,
                edgecolors=proto_color,
                linewidths=4,
            )

    # axis labels
    default_title = None
    if method == "PCA":
        if pca.n_components >= 1:
            ax.set_xlabel(
                f"1st eigenvector ({pca.explained_variance_ratio_[0]:.2%})"
            )
        if pca.n_components >= 2:
            ax.set_ylabel(
                f"2nd eigenvector ({pca.explained_variance_ratio_[1]:.2%})"
            )
        if pca.n_components == 3:
            ax.set_zlabel(
                f"3rd eigenvector ({pca.explained_variance_ratio_[2]:.2%})"
            )
        default_title = "PCA directions"
    elif method == "omega":
        if data.shape[1] >= 1:
            ax.set_xlabel("projection dimension 1")
        if data.shape[1] >= 2:
            ax.set_ylabel("projection dimension 2")
        if data.shape[1] == 3:
            ax.set_zlabel("projection dimension 3")
        default_title = "Omega transform"

    # An explicit title wins over the method-specific default.
    if title is not None:
        ax.set_title(title)
    elif default_title is not None:
        ax.set_title(default_title)

    # legend: one patch per class, coloured like the scattered points
    if classes:
        legend_cmap = mpl.colormaps[cmap].resampled(len(classes))
        ax.legend(
            handles=[
                mpatches.Patch(color=legend_cmap(i, alpha=alpha), label=target)
                for i, target in enumerate(classes)
            ]
        )

    if filename:
        # An f-string accepts both `str` and `Path`; `Path + str` would raise.
        plt.savefig(f"{filename}.png", dpi=600, bbox_inches="tight")
    plt.show()
    # Figure number 1 is reused on every call; close it so the next call
    # starts from an empty figure.
    plt.close(fig)