Initial commit.
Includes: visualiser.py with documentation for visualising the CCM and data.
This commit is contained in:
parent
f65705e8c2
commit
76668e552c
343
visualiser.py
Normal file
343
visualiser.py
Normal file
@ -0,0 +1,343 @@
|
|||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import matplotlib as mpl
|
||||||
|
import matplotlib.patches as mpatches
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
|
|
||||||
|
def unique_targets(targets):
|
||||||
|
"""Generate dictionary of integer indices and the corresponding targets."""
|
||||||
|
|
||||||
|
return {target: i for i, target in enumerate(sorted(set(targets)))}
|
||||||
|
|
||||||
|
|
||||||
|
def label2int(targets):
|
||||||
|
"""Convert string and/or integer labels to integer labels (necessary for
|
||||||
|
torch)."""
|
||||||
|
|
||||||
|
_unique_targets = unique_targets(targets)
|
||||||
|
targets = [_unique_targets[target] for target in targets]
|
||||||
|
return targets
|
||||||
|
|
||||||
|
|
||||||
|
def omega2lambda(omega: NDArray) -> NDArray:
|
||||||
|
"""Normalize Ω (Omega) and return the CCM (Λ, Lambda)."""
|
||||||
|
|
||||||
|
omega = omega / np.sqrt(np.trace(omega.T @ omega))
|
||||||
|
lmbd = omega @ omega.T
|
||||||
|
return lmbd
|
||||||
|
|
||||||
|
|
||||||
|
def vis_ccm(
|
||||||
|
ccm: NDArray,
|
||||||
|
labels: NDArray | list[str | int] | bool = True,
|
||||||
|
title="Classification correlation matrix",
|
||||||
|
cmap: str = "bwr",
|
||||||
|
important_labels: int | None = None,
|
||||||
|
no_diagonal: bool = False,
|
||||||
|
show_cip: bool = True,
|
||||||
|
show_crp: bool = False,
|
||||||
|
sort: bool = False,
|
||||||
|
cutout: int | None = None,
|
||||||
|
fontsize: float = 6,
|
||||||
|
filename: str | Path | None = None,
|
||||||
|
):
|
||||||
|
"""Visualise the Classification Correlation Matrix (CCM), Classification
|
||||||
|
Influence Profile (CIP) and Classification Relevance Profile (CRP).
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
-----------
|
||||||
|
ccm : NDArray
|
||||||
|
Classification Correlation Matrix (CCM) or Λ (Lambda) matrix.
|
||||||
|
labels : NDArray or list of integers and/or strings or bool
|
||||||
|
Supply a list of integers and/or strings. If `True`, generates an
|
||||||
|
integer index. If `False`, doesn't show labels.
|
||||||
|
title : str or None
|
||||||
|
Title of the figure.
|
||||||
|
cmap : str, default="bwr"
|
||||||
|
For options of matplotlib colormaps, see:
|
||||||
|
https://matplotlib.org/stable/tutorials/colors/colormaps.html
|
||||||
|
important_labels : int, default=None
|
||||||
|
Only show the `important_labels` most important labels.
|
||||||
|
no_diagonal : bool, default=False
|
||||||
|
Set the diagonal elements to zero.
|
||||||
|
show_cip : bool, default=True
|
||||||
|
Show the CIP.
|
||||||
|
show_crp : bool, default=True
|
||||||
|
Show the CRP.
|
||||||
|
sort : bool, default=False
|
||||||
|
Sort the features by their importance for classification (CIP).
|
||||||
|
cutout : int or None, default=None
|
||||||
|
If set, cuts out the `cutout` most important features and only shows
|
||||||
|
those.
|
||||||
|
fontsize : float, default=6
|
||||||
|
Overall fontsize.
|
||||||
|
filename : str or Path
|
||||||
|
If a filename is given, the figure is saved to that path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
font = {"fontsize": fontsize}
|
||||||
|
# profiles
|
||||||
|
cip = np.abs(ccm).sum(axis=0) / len(ccm)
|
||||||
|
crp = np.diagonal(ccm)
|
||||||
|
if labels is True:
|
||||||
|
ax_labels = list(range(len(ccm)))
|
||||||
|
else:
|
||||||
|
ax_labels = labels
|
||||||
|
|
||||||
|
if sort or cutout:
|
||||||
|
if cutout is not None:
|
||||||
|
greatest_importance = np.delete(np.argsort(cip)[::-1][:cutout], 2)
|
||||||
|
else:
|
||||||
|
greatest_importance = np.delete(np.argsort(cip)[::-1], 2)
|
||||||
|
|
||||||
|
new_ccm = np.zeros((len(greatest_importance), len(greatest_importance)))
|
||||||
|
for i, x in enumerate(greatest_importance):
|
||||||
|
for j, y in enumerate(greatest_importance):
|
||||||
|
new_ccm[i, j] = ccm[x, y]
|
||||||
|
ccm = new_ccm
|
||||||
|
ax_labels = [
|
||||||
|
ax_labels[i].replace("m_", "").replace("p_", "")
|
||||||
|
for i in greatest_importance.tolist()
|
||||||
|
]
|
||||||
|
# ax_labels = [ax_labels[i] for i in greatest_importance.tolist()]
|
||||||
|
|
||||||
|
# profiles
|
||||||
|
# cip = np.abs(ccm).sum(axis=0)
|
||||||
|
# crp = np.diagonal(ccm)
|
||||||
|
cip = cip[greatest_importance]
|
||||||
|
crp = crp[greatest_importance]
|
||||||
|
|
||||||
|
if no_diagonal:
|
||||||
|
np.fill_diagonal(ccm, 0)
|
||||||
|
|
||||||
|
if important_labels:
|
||||||
|
greatest_importance = np.argsort(cip)[::-1][:important_labels]
|
||||||
|
ax_labels = [
|
||||||
|
l if i in greatest_importance else "" for i, l in enumerate(ax_labels)
|
||||||
|
]
|
||||||
|
if not labels:
|
||||||
|
orig_labels = ax_labels
|
||||||
|
ax_labels = [""] * len(ax_labels)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
# matrix
|
||||||
|
plt.title(title)
|
||||||
|
|
||||||
|
maxi = max(-np.amin(ccm), np.amax(ccm)) * 1.1
|
||||||
|
vmin = -maxi
|
||||||
|
vmax = maxi
|
||||||
|
|
||||||
|
cax = ax.imshow(ccm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
|
||||||
|
|
||||||
|
# plt.xticks(np.arange(0, len(ax_labels)))
|
||||||
|
plt.yticks(np.arange(0, len(ax_labels)))
|
||||||
|
# ax.set_xticklabels([""] * len(ax_labels))
|
||||||
|
ax.set_yticklabels(ax_labels, fontdict=font)
|
||||||
|
if show_cip or show_crp:
|
||||||
|
plt.xticks([])
|
||||||
|
else:
|
||||||
|
plt.xticks(np.arange(0, len(ax_labels)))
|
||||||
|
ax.set_xticklabels(ax_labels, fontdict=font)
|
||||||
|
|
||||||
|
divider = make_axes_locatable(plt.gca())
|
||||||
|
|
||||||
|
# colorbar
|
||||||
|
ax_color = divider.append_axes("right", "5%", pad="10%")
|
||||||
|
plt.colorbar(cax, cax=ax_color)
|
||||||
|
|
||||||
|
if show_cip:
|
||||||
|
ax_inf = divider.append_axes("bottom", "30%", pad="5%")
|
||||||
|
ax_inf.plot(cip, marker=".", lw=0.6, ms=3, color="grey", label="CIP")
|
||||||
|
if show_crp:
|
||||||
|
ax_inf.plot(crp, marker=".", lw=0.6, ms=3, color="darkgrey", label="CRP")
|
||||||
|
ax_inf.fill_between(range(len(ccm)), cip, crp, alpha=0.3, color="lightgrey")
|
||||||
|
plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right")
|
||||||
|
|
||||||
|
if show_cip or show_crp:
|
||||||
|
ax_inf.set_xticks(np.arange(0, len(ax_labels)))
|
||||||
|
ax_inf.set_xticklabels(ax_labels, fontdict=font)
|
||||||
|
plt.xticks(rotation=45, ha="right")
|
||||||
|
ax_inf.set_ylim(bottom=0)
|
||||||
|
|
||||||
|
# plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right", fontsize=8)
|
||||||
|
plt.tight_layout()
|
||||||
|
if filename:
|
||||||
|
plt.savefig(filename + ".png", dpi=600, bbox_inches="tight")
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def vis_data(
|
||||||
|
data: NDArray,
|
||||||
|
targets: NDArray | list[int | str] | None = None,
|
||||||
|
prototypes: NDArray | None = None,
|
||||||
|
prototype_labels: NDArray | list[int | str] | None = None,
|
||||||
|
omega: NDArray | None = None,
|
||||||
|
dimensions: int | None = None,
|
||||||
|
title: str | None = None,
|
||||||
|
size: float = 20,
|
||||||
|
alpha: float = 0.5,
|
||||||
|
cmap: str = "tab10",
|
||||||
|
proto_marker: str = "o",
|
||||||
|
proto_color: str = "r",
|
||||||
|
filename: str | Path | None = None,
|
||||||
|
):
|
||||||
|
"""Visualise multidimensional data in one to three dimensions.
|
||||||
|
Dimensionality reduction is achieved through either Principal Component
|
||||||
|
Analysis (PCA) or, if the Ω (Omega) matrix is supplied, by projection into the
|
||||||
|
latent space.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : NDArray
|
||||||
|
Data array.
|
||||||
|
targets : NDArray or list of integers and/or strings or None, default=None
|
||||||
|
Supply a list of targets (class labels) as integers and/or strings. If
|
||||||
|
`None`, it doesn't show labels.
|
||||||
|
prototypes : NDArray or None, default=None
|
||||||
|
If prototype array is supplied, they are displayed and highlighted (see
|
||||||
|
`proto_marker` and `proto_color`).
|
||||||
|
prototype_labels : NDArray or list of integers and/or strings or None,
|
||||||
|
default=None
|
||||||
|
Supply a list of targets (class labels) as integers and/or strings. If
|
||||||
|
`None`, it doesn't show labels.
|
||||||
|
omega : NDArray or None, default=None
|
||||||
|
Ω (Omega) matrix for projection into latent space (see General Matrix
|
||||||
|
Learning Vector Quantization (GMLVQ)). If not supplied, PCA is chosen
|
||||||
|
for dimensionality reduction.
|
||||||
|
dimensions : int or None, default=None
|
||||||
|
Desired number of projection dimensions. If not supplied, chooses
|
||||||
|
either according to shape of Ω (Omega) or `2` for PCA.
|
||||||
|
title : str
|
||||||
|
Title of the figure.
|
||||||
|
size : float, default=20
|
||||||
|
Size of the points.
|
||||||
|
alpha : float, default=0.5
|
||||||
|
Alpha transparency for data points (e.g. for better visibility of
|
||||||
|
prototypes).
|
||||||
|
cmap : str, default="tab10"
|
||||||
|
Where to choose colours for classes from. For options of matplotlib
|
||||||
|
colormaps, see:
|
||||||
|
https://matplotlib.org/stable/tutorials/colors/colormaps.html
|
||||||
|
proto_marker : str, default="o"
|
||||||
|
For more options, see:
|
||||||
|
https://matplotlib.org/stable/api/markers_api.html#module-matplotlib.markers
|
||||||
|
proto_color : str, default="r"
|
||||||
|
Highlighting color for prototypes. For more options, see:
|
||||||
|
https://matplotlib.org/stable/gallery/color/named_colors.html#base-colors
|
||||||
|
filename : str or Path
|
||||||
|
If a filename is given, the figure is saved to that path.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fig = plt.figure(1, figsize=(8, 6))
|
||||||
|
n_data, n_dimensions = data.shape
|
||||||
|
target_reverse = {i: target for i, target in enumerate(sorted(set(targets)))}
|
||||||
|
method = ""
|
||||||
|
|
||||||
|
if omega is not None:
|
||||||
|
assert 1 <= omega.shape[1] <= 3, (
|
||||||
|
"Omega needs to be within the 3 dimensions "
|
||||||
|
"that humans are accustomed to."
|
||||||
|
)
|
||||||
|
if dimensions is None:
|
||||||
|
dimensions = omega.shape[1]
|
||||||
|
data = data @ omega
|
||||||
|
prototypes = prototypes @ omega
|
||||||
|
method = "omega"
|
||||||
|
if omega.shape[1] == 1:
|
||||||
|
data = np.concatenate((data, np.zeros((data.shape[0], 1))), axis=1)
|
||||||
|
prototypes = np.concatenate(
|
||||||
|
(prototypes, np.zeros((prototypes.shape[0], 1))), axis=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.shape[1] >= 3:
|
||||||
|
if dimensions is None:
|
||||||
|
dimensions = 2
|
||||||
|
pca = PCA(n_components=dimensions)
|
||||||
|
data = pca.fit_transform(data)
|
||||||
|
prototypes = pca.transform(prototypes)
|
||||||
|
method = "PCA"
|
||||||
|
|
||||||
|
# display data
|
||||||
|
if data.shape[1] <= 2:
|
||||||
|
ax = fig.add_subplot(111)
|
||||||
|
scatter = ax.scatter(
|
||||||
|
*data.T,
|
||||||
|
c=label2int(targets),
|
||||||
|
marker="o",
|
||||||
|
s=size,
|
||||||
|
edgecolors="none",
|
||||||
|
cmap=cmap,
|
||||||
|
alpha=alpha,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif data.shape[1] == 3:
|
||||||
|
ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110)
|
||||||
|
scatter = ax.scatter(
|
||||||
|
*data.T,
|
||||||
|
c=label2int(targets),
|
||||||
|
marker="o",
|
||||||
|
s=size,
|
||||||
|
edgecolors="none",
|
||||||
|
cmap=cmap,
|
||||||
|
alpha=alpha,
|
||||||
|
)
|
||||||
|
|
||||||
|
# display prototypes(?)
|
||||||
|
if prototypes is not None and prototype_labels is None:
|
||||||
|
warnings.warn(
|
||||||
|
"You forgot to supply labels to the prototypes to " "color them correctly."
|
||||||
|
)
|
||||||
|
|
||||||
|
if prototypes is not None and prototype_labels is not None:
|
||||||
|
ax.scatter(
|
||||||
|
*prototypes.T,
|
||||||
|
c=label2int(prototype_labels),
|
||||||
|
cmap=cmap,
|
||||||
|
marker=proto_marker,
|
||||||
|
s=size * 2,
|
||||||
|
edgecolors=proto_color,
|
||||||
|
linewidths=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# axis labels
|
||||||
|
if method == "PCA":
|
||||||
|
if pca.n_components >= 1:
|
||||||
|
ax.set_xlabel(
|
||||||
|
"1st eigenvector " f"({pca.explained_variance_ratio_[0]:.2%})"
|
||||||
|
)
|
||||||
|
if pca.n_components >= 2:
|
||||||
|
ax.set_ylabel(
|
||||||
|
"2nd eigenvector " f"({pca.explained_variance_ratio_[1]:.2%})"
|
||||||
|
)
|
||||||
|
if pca.n_components == 3:
|
||||||
|
ax.set_zlabel(
|
||||||
|
f"3rd eigenvector " f"({pca.explained_variance_ratio_[2]:.2%})"
|
||||||
|
)
|
||||||
|
ax.set_title("PCA directions")
|
||||||
|
elif method == "omega":
|
||||||
|
if data.shape[1] >= 1:
|
||||||
|
ax.set_xlabel("projection dimension 1")
|
||||||
|
if data.shape[1] >= 2:
|
||||||
|
ax.set_ylabel("projection dimension 2")
|
||||||
|
if data.shape[1] == 3:
|
||||||
|
ax.set_zlabel("projection dimension 3")
|
||||||
|
ax.set_title("Omega transform")
|
||||||
|
|
||||||
|
cmap = mpl.colormaps[cmap].resampled(len(set(targets)))
|
||||||
|
ax.legend(
|
||||||
|
handles=[
|
||||||
|
mpatches.Patch(color=cmap(t, alpha=alpha), label=target_reverse[t])
|
||||||
|
for t in target_reverse
|
||||||
|
]
|
||||||
|
)
|
||||||
|
if filename:
|
||||||
|
plt.savefig(filename + ".png", dpi=600, bbox_inches="tight")
|
||||||
|
plt.show()
|
Loading…
Reference in New Issue
Block a user