[BUGFIX] Fix siamese visualization callback
This commit is contained in:
		@@ -6,11 +6,12 @@ import torch
 | 
				
			|||||||
import torchvision
 | 
					import torchvision
 | 
				
			||||||
from matplotlib import pyplot as plt
 | 
					from matplotlib import pyplot as plt
 | 
				
			||||||
from matplotlib.offsetbox import AnchoredText
 | 
					from matplotlib.offsetbox import AnchoredText
 | 
				
			||||||
 | 
					from torch.utils.data import DataLoader, Dataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from prototorch.utils.celluloid import Camera
 | 
					from prototorch.utils.celluloid import Camera
 | 
				
			||||||
from prototorch.utils.colors import color_scheme
 | 
					from prototorch.utils.colors import color_scheme
 | 
				
			||||||
from prototorch.utils.utils import (gif_from_dir, make_directory,
 | 
					from prototorch.utils.utils import (gif_from_dir, make_directory,
 | 
				
			||||||
                                    prettify_string)
 | 
					                                    prettify_string)
 | 
				
			||||||
from torch.utils.data import DataLoader, Dataset
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VisWeights(pl.Callback):
 | 
					class VisWeights(pl.Callback):
 | 
				
			||||||
@@ -407,17 +408,17 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
 | 
				
			|||||||
                torch.Tensor(protos).to(pl_module.device)).cpu().detach()
 | 
					                torch.Tensor(protos).to(pl_module.device)).cpu().detach()
 | 
				
			||||||
        ax = self.setup_ax()
 | 
					        ax = self.setup_ax()
 | 
				
			||||||
        self.plot_data(ax, x_train, y_train)
 | 
					        self.plot_data(ax, x_train, y_train)
 | 
				
			||||||
        #if self.show_protos:
 | 
					        if self.show_protos:
 | 
				
			||||||
        #    self.plot_protos(ax, protos, plabels)
 | 
					            self.plot_protos(ax, protos, plabels)
 | 
				
			||||||
        #    x = np.vstack((x_train, protos))
 | 
					            x = np.vstack((x_train, protos))
 | 
				
			||||||
        #    mesh_input, xx, yy = self.get_mesh_input(x)
 | 
					            mesh_input, xx, yy = self.get_mesh_input(x)
 | 
				
			||||||
        #else:
 | 
					        else:
 | 
				
			||||||
        #    mesh_input, xx, yy = self.get_mesh_input(x_train)
 | 
					            mesh_input, xx, yy = self.get_mesh_input(x_train)
 | 
				
			||||||
        #_components = pl_module.proto_layer._components
 | 
					        _components = pl_module.proto_layer._components
 | 
				
			||||||
        #y_pred = pl_module.predict(
 | 
					        y_pred = pl_module.predict_latent(
 | 
				
			||||||
        #    torch.Tensor(mesh_input).type_as(_components))
 | 
					            torch.Tensor(mesh_input).type_as(_components))
 | 
				
			||||||
        #y_pred = y_pred.cpu().reshape(xx.shape)
 | 
					        y_pred = y_pred.cpu().reshape(xx.shape)
 | 
				
			||||||
        #ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
					        ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.log_and_display(trainer, pl_module)
 | 
					        self.log_and_display(trainer, pl_module)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user