Initial commit.

Includes: visualiser.py with documentation for visualising the CCM and
data.
This commit is contained in:
juvilius 2023-08-29 11:25:53 +02:00
parent f65705e8c2
commit 76668e552c
Signed by untrusted user who does not match committer: julius
GPG Key ID: 3EAC91A848E1D685

343
visualiser.py Normal file
View File

@ -0,0 +1,343 @@
import warnings
from pathlib import Path
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from numpy.typing import NDArray
from sklearn.decomposition import PCA
def unique_targets(targets):
"""Generate dictionary of integer indices and the corresponding targets."""
return {target: i for i, target in enumerate(sorted(set(targets)))}
def label2int(targets):
"""Convert string and/or integer labels to integer labels (necessary for
torch)."""
_unique_targets = unique_targets(targets)
targets = [_unique_targets[target] for target in targets]
return targets
def omega2lambda(omega: NDArray) -> NDArray:
"""Normalize Ω (Omega) and return the CCM (Λ, Lambda)."""
omega = omega / np.sqrt(np.trace(omega.T @ omega))
lmbd = omega @ omega.T
return lmbd
def vis_ccm(
ccm: NDArray,
labels: NDArray | list[str | int] | bool = True,
title="Classification correlation matrix",
cmap: str = "bwr",
important_labels: int | None = None,
no_diagonal: bool = False,
show_cip: bool = True,
show_crp: bool = False,
sort: bool = False,
cutout: int | None = None,
fontsize: float = 6,
filename: str | Path | None = None,
):
"""Visualise the Classification Correlation Matrix (CCM), Classification
Influence Profile (CIP) and Classification Relevance Profile (CRP).
Parameters:
-----------
ccm : NDArray
Classification Correlation Matrix (CCM) or Λ (Lambda) matrix.
labels : NDArray or list of integers and/or strings or bool
Supply a list of integers and/or strings. If `True`, generates an
integer index. If `False`, doesn't show labels.
title : str or None
Title of the figure.
cmap : str, default="bwr"
For options of matplotlib colormaps, see:
https://matplotlib.org/stable/tutorials/colors/colormaps.html
important_labels : int, default=None
Only show the `important_labels` most important labels.
no_diagonal : bool, default=False
Set the diagonal elements to zero.
show_cip : bool, default=True
Show the CIP.
show_crp : bool, default=True
Show the CRP.
sort : bool, default=False
Sort the features by their importance for classification (CIP).
cutout : int or None, default=None
If set, cuts out the `cutout` most important features and only shows
those.
fontsize : float, default=6
Overall fontsize.
filename : str or Path
If a filename is given, the figure is saved to that path.
"""
font = {"fontsize": fontsize}
# profiles
cip = np.abs(ccm).sum(axis=0) / len(ccm)
crp = np.diagonal(ccm)
if labels is True:
ax_labels = list(range(len(ccm)))
else:
ax_labels = labels
if sort or cutout:
if cutout is not None:
greatest_importance = np.delete(np.argsort(cip)[::-1][:cutout], 2)
else:
greatest_importance = np.delete(np.argsort(cip)[::-1], 2)
new_ccm = np.zeros((len(greatest_importance), len(greatest_importance)))
for i, x in enumerate(greatest_importance):
for j, y in enumerate(greatest_importance):
new_ccm[i, j] = ccm[x, y]
ccm = new_ccm
ax_labels = [
ax_labels[i].replace("m_", "").replace("p_", "")
for i in greatest_importance.tolist()
]
# ax_labels = [ax_labels[i] for i in greatest_importance.tolist()]
# profiles
# cip = np.abs(ccm).sum(axis=0)
# crp = np.diagonal(ccm)
cip = cip[greatest_importance]
crp = crp[greatest_importance]
if no_diagonal:
np.fill_diagonal(ccm, 0)
if important_labels:
greatest_importance = np.argsort(cip)[::-1][:important_labels]
ax_labels = [
l if i in greatest_importance else "" for i, l in enumerate(ax_labels)
]
if not labels:
orig_labels = ax_labels
ax_labels = [""] * len(ax_labels)
fig, ax = plt.subplots()
# matrix
plt.title(title)
maxi = max(-np.amin(ccm), np.amax(ccm)) * 1.1
vmin = -maxi
vmax = maxi
cax = ax.imshow(ccm, interpolation="nearest", cmap=cmap, vmin=vmin, vmax=vmax)
# plt.xticks(np.arange(0, len(ax_labels)))
plt.yticks(np.arange(0, len(ax_labels)))
# ax.set_xticklabels([""] * len(ax_labels))
ax.set_yticklabels(ax_labels, fontdict=font)
if show_cip or show_crp:
plt.xticks([])
else:
plt.xticks(np.arange(0, len(ax_labels)))
ax.set_xticklabels(ax_labels, fontdict=font)
divider = make_axes_locatable(plt.gca())
# colorbar
ax_color = divider.append_axes("right", "5%", pad="10%")
plt.colorbar(cax, cax=ax_color)
if show_cip:
ax_inf = divider.append_axes("bottom", "30%", pad="5%")
ax_inf.plot(cip, marker=".", lw=0.6, ms=3, color="grey", label="CIP")
if show_crp:
ax_inf.plot(crp, marker=".", lw=0.6, ms=3, color="darkgrey", label="CRP")
ax_inf.fill_between(range(len(ccm)), cip, crp, alpha=0.3, color="lightgrey")
plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right")
if show_cip or show_crp:
ax_inf.set_xticks(np.arange(0, len(ax_labels)))
ax_inf.set_xticklabels(ax_labels, fontdict=font)
plt.xticks(rotation=45, ha="right")
ax_inf.set_ylim(bottom=0)
# plt.legend(bbox_to_anchor=(1.43, 0.7), loc="center right", fontsize=8)
plt.tight_layout()
if filename:
plt.savefig(filename + ".png", dpi=600, bbox_inches="tight")
plt.show()
plt.close()
def vis_data(
data: NDArray,
targets: NDArray | list[int | str] | None = None,
prototypes: NDArray | None = None,
prototype_labels: NDArray | list[int | str] | None = None,
omega: NDArray | None = None,
dimensions: int | None = None,
title: str | None = None,
size: float = 20,
alpha: float = 0.5,
cmap: str = "tab10",
proto_marker: str = "o",
proto_color: str = "r",
filename: str | Path | None = None,
):
"""Visualise multidimensional data in one to three dimensions.
Dimensionality reduction is achieved through either Principal Component
Analysis (PCA) or, if the Ω (Omega) matrix is supplied, by projection into the
latent space.
Parameters
----------
data : NDArray
Data array.
targets : NDArray or list of integers and/or strings or None, default=None
Supply a list of targets (class labels) as integers and/or strings. If
`None`, it doesn't show labels.
prototypes : NDArray or None, default=None
If prototype array is supplied, they are displayed and highlighted (see
`proto_marker` and `proto_color`).
prototype_labels : NDArray or list of integers and/or strings or None,
default=None
Supply a list of targets (class labels) as integers and/or strings. If
`None`, it doesn't show labels.
omega : NDArray or None, default=None
Ω (Omega) matrix for projection into latent space (see General Matrix
Learning Vector Quantization (GMLVQ)). If not supplied, PCA is chosen
for dimensionality reduction.
dimensions : int or None, default=None
Desired number of projection dimensions. If not supplied, chooses
either according to shape of Ω (Omega) or `2` for PCA.
title : str
Title of the figure.
size : float, default=20
Size of the points.
alpha : float, default=0.5
Alpha transparency for data points (e.g. for better visibility of
prototypes).
cmap : str, default="tab10"
Where to choose colours for classes from. For options of matplotlib
colormaps, see:
https://matplotlib.org/stable/tutorials/colors/colormaps.html
proto_marker : str, default="o"
For more options, see:
https://matplotlib.org/stable/api/markers_api.html#module-matplotlib.markers
proto_color : str, default="r"
Highlighting color for prototypes. For more options, see:
https://matplotlib.org/stable/gallery/color/named_colors.html#base-colors
filename : str or Path
If a filename is given, the figure is saved to that path.
"""
fig = plt.figure(1, figsize=(8, 6))
n_data, n_dimensions = data.shape
target_reverse = {i: target for i, target in enumerate(sorted(set(targets)))}
method = ""
if omega is not None:
assert 1 <= omega.shape[1] <= 3, (
"Omega needs to be within the 3 dimensions "
"that humans are accustomed to."
)
if dimensions is None:
dimensions = omega.shape[1]
data = data @ omega
prototypes = prototypes @ omega
method = "omega"
if omega.shape[1] == 1:
data = np.concatenate((data, np.zeros((data.shape[0], 1))), axis=1)
prototypes = np.concatenate(
(prototypes, np.zeros((prototypes.shape[0], 1))), axis=1
)
if data.shape[1] >= 3:
if dimensions is None:
dimensions = 2
pca = PCA(n_components=dimensions)
data = pca.fit_transform(data)
prototypes = pca.transform(prototypes)
method = "PCA"
# display data
if data.shape[1] <= 2:
ax = fig.add_subplot(111)
scatter = ax.scatter(
*data.T,
c=label2int(targets),
marker="o",
s=size,
edgecolors="none",
cmap=cmap,
alpha=alpha,
)
elif data.shape[1] == 3:
ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110)
scatter = ax.scatter(
*data.T,
c=label2int(targets),
marker="o",
s=size,
edgecolors="none",
cmap=cmap,
alpha=alpha,
)
# display prototypes(?)
if prototypes is not None and prototype_labels is None:
warnings.warn(
"You forgot to supply labels to the prototypes to " "color them correctly."
)
if prototypes is not None and prototype_labels is not None:
ax.scatter(
*prototypes.T,
c=label2int(prototype_labels),
cmap=cmap,
marker=proto_marker,
s=size * 2,
edgecolors=proto_color,
linewidths=4,
)
# axis labels
if method == "PCA":
if pca.n_components >= 1:
ax.set_xlabel(
"1st eigenvector " f"({pca.explained_variance_ratio_[0]:.2%})"
)
if pca.n_components >= 2:
ax.set_ylabel(
"2nd eigenvector " f"({pca.explained_variance_ratio_[1]:.2%})"
)
if pca.n_components == 3:
ax.set_zlabel(
f"3rd eigenvector " f"({pca.explained_variance_ratio_[2]:.2%})"
)
ax.set_title("PCA directions")
elif method == "omega":
if data.shape[1] >= 1:
ax.set_xlabel("projection dimension 1")
if data.shape[1] >= 2:
ax.set_ylabel("projection dimension 2")
if data.shape[1] == 3:
ax.set_zlabel("projection dimension 3")
ax.set_title("Omega transform")
cmap = mpl.colormaps[cmap].resampled(len(set(targets)))
ax.legend(
handles=[
mpatches.Patch(color=cmap(t, alpha=alpha), label=target_reverse[t])
for t in target_reverse
]
)
if filename:
plt.savefig(filename + ".png", dpi=600, bbox_inches="tight")
plt.show()