feat: add VisSpectralProtos
This commit is contained in:
parent
b31c8cc707
commit
ce14dec7e9
@ -7,6 +7,7 @@ import torchvision
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from ..utils.colors import get_colors, get_legend_handles
|
||||
from ..utils.utils import mesh2d
|
||||
|
||||
|
||||
@ -18,6 +19,7 @@ class Vis2DAbstract(pl.Callback):
|
||||
cmap="viridis",
|
||||
xlabel="Data dimension 1",
|
||||
ylabel="Data dimension 2",
|
||||
legend_labels=None,
|
||||
border=0.1,
|
||||
resolution=100,
|
||||
flatten_data=True,
|
||||
@ -50,6 +52,7 @@ class Vis2DAbstract(pl.Callback):
|
||||
self.title = title
|
||||
self.xlabel = xlabel
|
||||
self.ylabel = ylabel
|
||||
self.legend_labels = legend_labels
|
||||
self.fig = plt.figure(self.title)
|
||||
self.cmap = cmap
|
||||
self.border = border
|
||||
@ -249,6 +252,24 @@ class VisNG2D(Vis2DAbstract):
|
||||
)
|
||||
|
||||
|
||||
class VisSpectralProtos(Vis2DAbstract):
|
||||
|
||||
def visualize(self, pl_module):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
ax = self.setup_ax()
|
||||
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
|
||||
for p, pl in zip(protos, plabels):
|
||||
ax.plot(p, c=colors[int(pl)])
|
||||
if self.legend_labels:
|
||||
handles = get_legend_handles(
|
||||
colors,
|
||||
self.legend_labels,
|
||||
marker="lines",
|
||||
)
|
||||
ax.legend(handles=handles)
|
||||
|
||||
|
||||
class VisImgComp(Vis2DAbstract):
|
||||
|
||||
def __init__(self,
|
||||
|
Loading…
Reference in New Issue
Block a user