Initial commit.
Includes: visualiser.py with documentation for visualising the CCM and data.
This commit is contained in:
parent
f65705e8c2
commit
76668e552c
343
visualiser.py
Normal file
343
visualiser.py
Normal file
@ -0,0 +1,343 @@
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib as mpl
|
||||
import matplotlib.patches as mpatches
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
|
||||
from numpy.typing import NDArray
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
|
||||
def unique_targets(targets):
|
||||
"""Generate dictionary of integer indices and the corresponding targets."""
|
||||
|
||||
return {target: i for i, target in enumerate(sorted(set(targets)))}
|
||||
|
||||
|
||||
def label2int(targets):
|
||||
"""Convert string and/or integer labels to integer labels (necessary for
|
||||
torch)."""
|
||||
|
||||
_unique_targets = unique_targets(targets)
|
||||
targets = [_unique_targets[target] for target in targets]
|
||||
return targets
|
||||
|
||||
|
||||
def omega2lambda(omega: NDArray) -> NDArray:
|
||||
"""Normalize Ω (Omega) and return the CCM (Λ, Lambda)."""
|
||||
|
||||
omega = omega / np.sqrt(np.trace(omega.T @ omega))
|
||||
lmbd = omega @ omega.T
|
||||
return lmbd
|
||||
|
||||
|
||||
def vis_ccm(
|
||||
ccm: NDArray,
|
||||
labels: NDArray | list[str | int] | bool = True,
|
||||
title="Classification correlation matrix",
|
||||
cmap: str = "bwr",
|
||||
important_labels: int | None = None,
|
||||
no_diagonal: bool = False,
|
||||
show_cip: bool = True,
|
||||
show_crp: bool = False,
|
||||
sort: bool = False,
|
||||
cutout: int | None = None,
|
||||
fontsize: float = 6,
|
||||
filename: str | Path | None = None,
|
||||
):
|
||||
"""Visualise the Classification Correlation Matrix (CCM), Classification
|
||||
Influence Profile (CIP) and Classification Relevance Profile (CRP).
|
||||
|
||||
Parameters:
|
||||
-----------
|
||||
ccm : NDArray
|
||||
Classification Correlation Matrix (CCM) or Λ (Lambda) matrix.
|
||||
labels : NDArray or list of integers and/or strings or bool
|
||||
Supply a list of integers and/or strings. If `True`, generates an
|
||||
integer index. If `False`, doesn't show labels.
|
||||
title : str or None
|
||||
Title of the figure.
|
||||
cmap : str, default="bwr"
|
||||
For options of matplotlib colormaps, see:
|
||||
https://matplotlib.org/stable/tutorials/colors/colormaps.html
|
||||
important_labels : int, default=None
|
||||
Only show the `important_labels` most important labels.
|
||||
no_diagonal : bool, default=False
|
||||
Set the diagonal elements to zero.
|
||||
show_cip : bool, default=True
|
||||
Show the CIP.
|
||||
show_crp : bool, default=True
|
||||
Show the CRP.
|
||||
sort : bool, default=False
|
||||
Sort the features by their importance for classification (CIP).
|
||||
cutout : int or None, default=None
|
||||
If set, cuts out the `cutout` most important features and only shows
|
||||
those.
|
||||
fontsize : float, default=6
|
||||
Overall fontsize.
|
||||
filename : str or Path
|
||||
If a filename is given, the figure is saved to that path.
|
||||
"""
|
||||
|
||||
font = {"fontsize": fontsize}
|
||||
# profiles
|
||||
cip = np.abs(ccm).sum(axis=0) / len(ccm)
|
||||
crp = np.diagonal(ccm)
|
||||
if labels is True:
|
||||
ax_labels = list(range(len(ccm)))
|
||||
else:
|
||||
ax_labels = labels
|
||||
|
||||
if sort or cutout:
|
||||
if cutout is not None:
|
||||
greatest_importance = np.delete(np.argsort(cip)[::-1][:cutout], 2)
|
||||
else:
|
||||
greatest_importance = np.delete(np.argsort(cip)[::-1], 2)
|
||||
|
||||
new_ccm = np.zeros((len(greatest_importance), len(greatest_importance)))
|
||||
for i, x in enumerate(greatest_importance):
|
||||
for j, y in enumerate(greatest_importance):
|
||||
new_ccm[i, j] = ccm[x, y]
|
||||
ccm = new_ccm
|
||||
ax_labels = [
|
||||
ax_labels[i].replace("m_", "").replace("p_", "")
|
||||
for i in greatest_importance.tolist()
|
||||
]
|
||||
# ax_labels = [ax_labels[i] for i in greatest_importance.tolist()]
|
||||
|
||||
# profiles
|
||||
# cip = np.abs(ccm).sum(axis=0)
|
||||
# crp = np.diagonal(ccm)
|
||||
cip = cip[greatest_importance]
|
||||
crp = crp[greatest_importance]
|
||||
|
||||
if no_diagonal:
|
||||
np.fill_diagonal(ccm, 0)
|
||||
|
||||
if important_labels:
|
||||
greatest_importance = np.argsort(cip)[::-1][:important_labels]
|
||||
ax_labels = [
|
||||
l if i in greatest_importance else "" for i, l in enumerate(ax_labels)
|
||||
]
|
||||
if not labels:
|
||||
orig_labels = ax_labels
|
||||
ax_labels = [""] * len(ax_labels)
|
||||
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
# matrix
|
||||
plt.title(title)
|
||||
|
||||
maxi = max(-np.amin(ccm), np.amax(ccm)) * 1.1
|
||||
vmin = -maxi
|
||||
vmax = maxi
|
||||
|
||||
cax = ax.imshow(ccm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
|
||||
|
||||
# plt.xticks(np.arange(0, len(ax_labels)))
|
||||
plt.yticks(np.arange(0, len(ax_labels)))
|
||||
# ax.set_xticklabels([""] * len(ax_labels))
|
||||
ax.set_yticklabels(ax_labels, fontdict=font)
|
||||
if show_cip or show_crp:
|
||||
plt.xticks([])
|
||||
else:
|
||||
plt.xticks(np.arange(0, len(ax_labels)))
|
||||
ax.set_xticklabels(ax_labels, fontdict=font)
|
||||
|
||||
divider = make_axes_locatable(plt.gca())
|
||||
|
||||
# colorbar
|
||||
ax_color = divider.append_axes("right", "5%", pad="10%")
|
||||
plt.colorbar(cax, cax=ax_color)
|
||||
|
||||
if show_cip:
|
||||
ax_inf = divider.append_axes("bottom", "30%", pad="5%")
|
||||
ax_inf.plot(cip, marker=".", lw=0.6, ms=3, color="grey", label="CIP")
|
||||
if show_crp:
|
||||
ax_inf.plot(crp, marker=".", lw=0.6, ms=3, color="darkgrey", label="CRP")
|
||||
ax_inf.fill_between(range(len(ccm)), cip, crp, alpha=0.3, color="lightgrey")
|
||||
plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right")
|
||||
|
||||
if show_cip or show_crp:
|
||||
ax_inf.set_xticks(np.arange(0, len(ax_labels)))
|
||||
ax_inf.set_xticklabels(ax_labels, fontdict=font)
|
||||
plt.xticks(rotation=45, ha="right")
|
||||
ax_inf.set_ylim(bottom=0)
|
||||
|
||||
# plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right", fontsize=8)
|
||||
plt.tight_layout()
|
||||
if filename:
|
||||
plt.savefig(filename + ".png", dpi=600, bbox_inches="tight")
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
|
||||
def vis_data(
|
||||
data: NDArray,
|
||||
targets: NDArray | list[int | str] | None = None,
|
||||
prototypes: NDArray | None = None,
|
||||
prototype_labels: NDArray | list[int | str] | None = None,
|
||||
omega: NDArray | None = None,
|
||||
dimensions: int | None = None,
|
||||
title: str | None = None,
|
||||
size: float = 20,
|
||||
alpha: float = 0.5,
|
||||
cmap: str = "tab10",
|
||||
proto_marker: str = "o",
|
||||
proto_color: str = "r",
|
||||
filename: str | Path | None = None,
|
||||
):
|
||||
"""Visualise multidimensional data in one to three dimensions.
|
||||
Dimensionality reduction is achieved through either Principal Component
|
||||
Analysis (PCA) or, if the Ω (Omega) matrix is supplied, by projection into the
|
||||
latent space.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data : NDArray
|
||||
Data array.
|
||||
targets : NDArray or list of integers and/or strings or None, default=None
|
||||
Supply a list of targets (class labels) as integers and/or strings. If
|
||||
`None`, it doesn't show labels.
|
||||
prototypes : NDArray or None, default=None
|
||||
If prototype array is supplied, they are displayed and highlighted (see
|
||||
`proto_marker` and `proto_color`).
|
||||
prototype_labels : NDArray or list of integers and/or strings or None,
|
||||
default=None
|
||||
Supply a list of targets (class labels) as integers and/or strings. If
|
||||
`None`, it doesn't show labels.
|
||||
omega : NDArray or None, default=None
|
||||
Ω (Omega) matrix for projection into latent space (see General Matrix
|
||||
Learning Vector Quantization (GMLVQ)). If not supplied, PCA is chosen
|
||||
for dimensionality reduction.
|
||||
dimensions : int or None, default=None
|
||||
Desired number of projection dimensions. If not supplied, chooses
|
||||
either according to shape of Ω (Omega) or `2` for PCA.
|
||||
title : str
|
||||
Title of the figure.
|
||||
size : float, default=20
|
||||
Size of the points.
|
||||
alpha : float, default=0.5
|
||||
Alpha transparency for data points (e.g. for better visibility of
|
||||
prototypes).
|
||||
cmap : str, default="tab10"
|
||||
Where to choose colours for classes from. For options of matplotlib
|
||||
colormaps, see:
|
||||
https://matplotlib.org/stable/tutorials/colors/colormaps.html
|
||||
proto_marker : str, default="o"
|
||||
For more options, see:
|
||||
https://matplotlib.org/stable/api/markers_api.html#module-matplotlib.markers
|
||||
proto_color : str, default="r"
|
||||
Highlighting color for prototypes. For more options, see:
|
||||
https://matplotlib.org/stable/gallery/color/named_colors.html#base-colors
|
||||
filename : str or Path
|
||||
If a filename is given, the figure is saved to that path.
|
||||
"""
|
||||
|
||||
fig = plt.figure(1, figsize=(8, 6))
|
||||
n_data, n_dimensions = data.shape
|
||||
target_reverse = {i: target for i, target in enumerate(sorted(set(targets)))}
|
||||
method = ""
|
||||
|
||||
if omega is not None:
|
||||
assert 1 <= omega.shape[1] <= 3, (
|
||||
"Omega needs to be within the 3 dimensions "
|
||||
"that humans are accustomed to."
|
||||
)
|
||||
if dimensions is None:
|
||||
dimensions = omega.shape[1]
|
||||
data = data @ omega
|
||||
prototypes = prototypes @ omega
|
||||
method = "omega"
|
||||
if omega.shape[1] == 1:
|
||||
data = np.concatenate((data, np.zeros((data.shape[0], 1))), axis=1)
|
||||
prototypes = np.concatenate(
|
||||
(prototypes, np.zeros((prototypes.shape[0], 1))), axis=1
|
||||
)
|
||||
|
||||
if data.shape[1] >= 3:
|
||||
if dimensions is None:
|
||||
dimensions = 2
|
||||
pca = PCA(n_components=dimensions)
|
||||
data = pca.fit_transform(data)
|
||||
prototypes = pca.transform(prototypes)
|
||||
method = "PCA"
|
||||
|
||||
# display data
|
||||
if data.shape[1] <= 2:
|
||||
ax = fig.add_subplot(111)
|
||||
scatter = ax.scatter(
|
||||
*data.T,
|
||||
c=label2int(targets),
|
||||
marker="o",
|
||||
s=size,
|
||||
edgecolors="none",
|
||||
cmap=cmap,
|
||||
alpha=alpha,
|
||||
)
|
||||
|
||||
elif data.shape[1] == 3:
|
||||
ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110)
|
||||
scatter = ax.scatter(
|
||||
*data.T,
|
||||
c=label2int(targets),
|
||||
marker="o",
|
||||
s=size,
|
||||
edgecolors="none",
|
||||
cmap=cmap,
|
||||
alpha=alpha,
|
||||
)
|
||||
|
||||
# display prototypes(?)
|
||||
if prototypes is not None and prototype_labels is None:
|
||||
warnings.warn(
|
||||
"You forgot to supply labels to the prototypes to " "color them correctly."
|
||||
)
|
||||
|
||||
if prototypes is not None and prototype_labels is not None:
|
||||
ax.scatter(
|
||||
*prototypes.T,
|
||||
c=label2int(prototype_labels),
|
||||
cmap=cmap,
|
||||
marker=proto_marker,
|
||||
s=size * 2,
|
||||
edgecolors=proto_color,
|
||||
linewidths=4,
|
||||
)
|
||||
|
||||
# axis labels
|
||||
if method == "PCA":
|
||||
if pca.n_components >= 1:
|
||||
ax.set_xlabel(
|
||||
"1st eigenvector " f"({pca.explained_variance_ratio_[0]:.2%})"
|
||||
)
|
||||
if pca.n_components >= 2:
|
||||
ax.set_ylabel(
|
||||
"2nd eigenvector " f"({pca.explained_variance_ratio_[1]:.2%})"
|
||||
)
|
||||
if pca.n_components == 3:
|
||||
ax.set_zlabel(
|
||||
f"3rd eigenvector " f"({pca.explained_variance_ratio_[2]:.2%})"
|
||||
)
|
||||
ax.set_title("PCA directions")
|
||||
elif method == "omega":
|
||||
if data.shape[1] >= 1:
|
||||
ax.set_xlabel("projection dimension 1")
|
||||
if data.shape[1] >= 2:
|
||||
ax.set_ylabel("projection dimension 2")
|
||||
if data.shape[1] == 3:
|
||||
ax.set_zlabel("projection dimension 3")
|
||||
ax.set_title("Omega transform")
|
||||
|
||||
cmap = mpl.colormaps[cmap].resampled(len(set(targets)))
|
||||
ax.legend(
|
||||
handles=[
|
||||
mpatches.Patch(color=cmap(t, alpha=alpha), label=target_reverse[t])
|
||||
for t in target_reverse
|
||||
]
|
||||
)
|
||||
if filename:
|
||||
plt.savefig(filename + ".png", dpi=600, bbox_inches="tight")
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user