[BUGFIX] Fix siamese visualization callback
This commit is contained in:
parent
b7684ae512
commit
6e7d80be88
@ -6,11 +6,12 @@ import torch
|
|||||||
import torchvision
|
import torchvision
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from matplotlib.offsetbox import AnchoredText
|
from matplotlib.offsetbox import AnchoredText
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from prototorch.utils.celluloid import Camera
|
from prototorch.utils.celluloid import Camera
|
||||||
from prototorch.utils.colors import color_scheme
|
from prototorch.utils.colors import color_scheme
|
||||||
from prototorch.utils.utils import (gif_from_dir, make_directory,
|
from prototorch.utils.utils import (gif_from_dir, make_directory,
|
||||||
prettify_string)
|
prettify_string)
|
||||||
from torch.utils.data import DataLoader, Dataset
|
|
||||||
|
|
||||||
|
|
||||||
class VisWeights(pl.Callback):
|
class VisWeights(pl.Callback):
|
||||||
@ -407,17 +408,17 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
torch.Tensor(protos).to(pl_module.device)).cpu().detach()
|
torch.Tensor(protos).to(pl_module.device)).cpu().detach()
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax()
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
#if self.show_protos:
|
if self.show_protos:
|
||||||
# self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
# x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
# mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
#else:
|
else:
|
||||||
# mesh_input, xx, yy = self.get_mesh_input(x_train)
|
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
||||||
#_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
#y_pred = pl_module.predict(
|
y_pred = pl_module.predict_latent(
|
||||||
# torch.Tensor(mesh_input).type_as(_components))
|
torch.Tensor(mesh_input).type_as(_components))
|
||||||
#y_pred = y_pred.cpu().reshape(xx.shape)
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user