feat: add VisSpectralProtos
This commit is contained in:
parent
b31c8cc707
commit
ce14dec7e9
@ -7,6 +7,7 @@ import torchvision
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from ..utils.colors import get_colors, get_legend_handles
|
||||||
from ..utils.utils import mesh2d
|
from ..utils.utils import mesh2d
|
||||||
|
|
||||||
|
|
||||||
@ -18,6 +19,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
cmap="viridis",
|
cmap="viridis",
|
||||||
xlabel="Data dimension 1",
|
xlabel="Data dimension 1",
|
||||||
ylabel="Data dimension 2",
|
ylabel="Data dimension 2",
|
||||||
|
legend_labels=None,
|
||||||
border=0.1,
|
border=0.1,
|
||||||
resolution=100,
|
resolution=100,
|
||||||
flatten_data=True,
|
flatten_data=True,
|
||||||
@ -50,6 +52,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
self.title = title
|
self.title = title
|
||||||
self.xlabel = xlabel
|
self.xlabel = xlabel
|
||||||
self.ylabel = ylabel
|
self.ylabel = ylabel
|
||||||
|
self.legend_labels = legend_labels
|
||||||
self.fig = plt.figure(self.title)
|
self.fig = plt.figure(self.title)
|
||||||
self.cmap = cmap
|
self.cmap = cmap
|
||||||
self.border = border
|
self.border = border
|
||||||
@ -249,6 +252,24 @@ class VisNG2D(Vis2DAbstract):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VisSpectralProtos(Vis2DAbstract):
|
||||||
|
|
||||||
|
def visualize(self, pl_module):
|
||||||
|
protos = pl_module.prototypes
|
||||||
|
plabels = pl_module.prototype_labels
|
||||||
|
ax = self.setup_ax()
|
||||||
|
colors = get_colors(vmax=max(plabels), vmin=min(plabels))
|
||||||
|
for p, pl in zip(protos, plabels):
|
||||||
|
ax.plot(p, c=colors[int(pl)])
|
||||||
|
if self.legend_labels:
|
||||||
|
handles = get_legend_handles(
|
||||||
|
colors,
|
||||||
|
self.legend_labels,
|
||||||
|
marker="lines",
|
||||||
|
)
|
||||||
|
ax.legend(handles=handles)
|
||||||
|
|
||||||
|
|
||||||
class VisImgComp(Vis2DAbstract):
|
class VisImgComp(Vis2DAbstract):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
Loading…
Reference in New Issue
Block a user