[BUGFIX] Fix siamese visualization callback

This commit is contained in:
Jensun Ravichandran 2021-05-15 12:52:44 +02:00
parent b7684ae512
commit 6e7d80be88

View File

@ -6,11 +6,12 @@ import torch
import torchvision import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText from matplotlib.offsetbox import AnchoredText
from torch.utils.data import DataLoader, Dataset
from prototorch.utils.celluloid import Camera from prototorch.utils.celluloid import Camera
from prototorch.utils.colors import color_scheme from prototorch.utils.colors import color_scheme
from prototorch.utils.utils import (gif_from_dir, make_directory, from prototorch.utils.utils import (gif_from_dir, make_directory,
prettify_string) prettify_string)
from torch.utils.data import DataLoader, Dataset
class VisWeights(pl.Callback): class VisWeights(pl.Callback):
@ -407,17 +408,17 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
torch.Tensor(protos).to(pl_module.device)).cpu().detach() torch.Tensor(protos).to(pl_module.device)).cpu().detach()
ax = self.setup_ax() ax = self.setup_ax()
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
#if self.show_protos: if self.show_protos:
# self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
# x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
# mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
#else: else:
# mesh_input, xx, yy = self.get_mesh_input(x_train) mesh_input, xx, yy = self.get_mesh_input(x_train)
#_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
#y_pred = pl_module.predict( y_pred = pl_module.predict_latent(
# torch.Tensor(mesh_input).type_as(_components)) torch.Tensor(mesh_input).type_as(_components))
#y_pred = y_pred.cpu().reshape(xx.shape) y_pred = y_pred.cpu().reshape(xx.shape)
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)