feat: add VisSpectralProtos
				
					
				
			This commit is contained in:
		@@ -7,6 +7,7 @@ import torchvision
 | 
				
			|||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
from torch.utils.data import DataLoader, Dataset
 | 
					from torch.utils.data import DataLoader, Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..utils.colors import get_colors, get_legend_handles
 | 
				
			||||||
from ..utils.utils import mesh2d
 | 
					from ..utils.utils import mesh2d
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -18,6 +19,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
                 cmap="viridis",
 | 
					                 cmap="viridis",
 | 
				
			||||||
                 xlabel="Data dimension 1",
 | 
					                 xlabel="Data dimension 1",
 | 
				
			||||||
                 ylabel="Data dimension 2",
 | 
					                 ylabel="Data dimension 2",
 | 
				
			||||||
 | 
					                 legend_labels=None,
 | 
				
			||||||
                 border=0.1,
 | 
					                 border=0.1,
 | 
				
			||||||
                 resolution=100,
 | 
					                 resolution=100,
 | 
				
			||||||
                 flatten_data=True,
 | 
					                 flatten_data=True,
 | 
				
			||||||
@@ -50,6 +52,7 @@ class Vis2DAbstract(pl.Callback):
 | 
				
			|||||||
        self.title = title
 | 
					        self.title = title
 | 
				
			||||||
        self.xlabel = xlabel
 | 
					        self.xlabel = xlabel
 | 
				
			||||||
        self.ylabel = ylabel
 | 
					        self.ylabel = ylabel
 | 
				
			||||||
 | 
					        self.legend_labels = legend_labels
 | 
				
			||||||
        self.fig = plt.figure(self.title)
 | 
					        self.fig = plt.figure(self.title)
 | 
				
			||||||
        self.cmap = cmap
 | 
					        self.cmap = cmap
 | 
				
			||||||
        self.border = border
 | 
					        self.border = border
 | 
				
			||||||
@@ -249,6 +252,24 @@ class VisNG2D(Vis2DAbstract):
 | 
				
			|||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VisSpectralProtos(Vis2DAbstract):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def visualize(self, pl_module):
 | 
				
			||||||
 | 
					        protos = pl_module.prototypes
 | 
				
			||||||
 | 
					        plabels = pl_module.prototype_labels
 | 
				
			||||||
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
 | 
					        colors = get_colors(vmax=max(plabels), vmin=min(plabels))
 | 
				
			||||||
 | 
					        for p, pl in zip(protos, plabels):
 | 
				
			||||||
 | 
					            ax.plot(p, c=colors[int(pl)])
 | 
				
			||||||
 | 
					        if self.legend_labels:
 | 
				
			||||||
 | 
					            handles = get_legend_handles(
 | 
				
			||||||
 | 
					                colors,
 | 
				
			||||||
 | 
					                self.legend_labels,
 | 
				
			||||||
 | 
					                marker="lines",
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					            ax.legend(handles=handles)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisImgComp(Vis2DAbstract):
 | 
					class VisImgComp(Vis2DAbstract):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self,
 | 
					    def __init__(self,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user