feat: add VisSpectralProtos

This commit is contained in:
Jensun Ravichandran 2022-03-10 15:24:44 +01:00
parent b31c8cc707
commit ce14dec7e9
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -7,6 +7,7 @@ import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from ..utils.colors import get_colors, get_legend_handles
from ..utils.utils import mesh2d from ..utils.utils import mesh2d
@ -18,6 +19,7 @@ class Vis2DAbstract(pl.Callback):
cmap="viridis", cmap="viridis",
xlabel="Data dimension 1", xlabel="Data dimension 1",
ylabel="Data dimension 2", ylabel="Data dimension 2",
legend_labels=None,
border=0.1, border=0.1,
resolution=100, resolution=100,
flatten_data=True, flatten_data=True,
@ -50,6 +52,7 @@ class Vis2DAbstract(pl.Callback):
self.title = title self.title = title
self.xlabel = xlabel self.xlabel = xlabel
self.ylabel = ylabel self.ylabel = ylabel
self.legend_labels = legend_labels
self.fig = plt.figure(self.title) self.fig = plt.figure(self.title)
self.cmap = cmap self.cmap = cmap
self.border = border self.border = border
@ -249,6 +252,24 @@ class VisNG2D(Vis2DAbstract):
) )
class VisSpectralProtos(Vis2DAbstract):
def visualize(self, pl_module):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
ax = self.setup_ax()
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
for p, pl in zip(protos, plabels):
ax.plot(p, c=colors[int(pl)])
if self.legend_labels:
handles = get_legend_handles(
colors,
self.legend_labels,
marker="lines",
)
ax.legend(handles=handles)
class VisImgComp(Vis2DAbstract): class VisImgComp(Vis2DAbstract):
def __init__(self, def __init__(self,