diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 697421a..2018ec9 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -7,6 +7,7 @@ import torchvision from matplotlib import pyplot as plt from torch.utils.data import DataLoader, Dataset +from ..utils.colors import get_colors, get_legend_handles from ..utils.utils import mesh2d @@ -18,6 +19,7 @@ class Vis2DAbstract(pl.Callback): cmap="viridis", xlabel="Data dimension 1", ylabel="Data dimension 2", + legend_labels=None, border=0.1, resolution=100, flatten_data=True, @@ -50,6 +52,7 @@ class Vis2DAbstract(pl.Callback): self.title = title self.xlabel = xlabel self.ylabel = ylabel + self.legend_labels = legend_labels self.fig = plt.figure(self.title) self.cmap = cmap self.border = border @@ -249,6 +252,24 @@ class VisNG2D(Vis2DAbstract): ) +class VisSpectralProtos(Vis2DAbstract): + + def visualize(self, pl_module): + protos = pl_module.prototypes + plabels = pl_module.prototype_labels + ax = self.setup_ax() + colors = get_colors(vmax=max(plabels), vmin=min(plabels)) + for p, pl in zip(protos, plabels): + ax.plot(p, c=colors[int(pl)]) + if self.legend_labels: + handles = get_legend_handles( + colors, + self.legend_labels, + marker="lines", + ) + ax.legend(handles=handles) + + class VisImgComp(Vis2DAbstract): def __init__(self,