diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index dfedfe5..4f6b696 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -7,6 +7,8 @@ import torchvision from matplotlib import pyplot as plt from torch.utils.data import DataLoader, Dataset +from ..utils.utils import mesh2d + class Vis2DAbstract(pl.Callback): def __init__(self, @@ -73,16 +75,6 @@ class Vis2DAbstract(pl.Callback): ax.axis("off") return ax - def get_mesh_input(self, x): - x_shift = self.border * np.ptp(x[:, 0]) - y_shift = self.border * np.ptp(x[:, 1]) - x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift - y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift - xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution), - np.linspace(y_min, y_max, self.resolution)) - mesh_input = np.c_[xx.ravel(), yy.ravel()] - return mesh_input, xx, yy - def plot_data(self, ax, x, y): ax.scatter( x[:, 0], @@ -138,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract): self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, plabels) x = np.vstack((x_train, protos)) - mesh_input, xx, yy = self.get_mesh_input(x) + mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) _components = pl_module.proto_layer._components mesh_input = torch.from_numpy(mesh_input).type_as(_components) y_pred = pl_module.predict(mesh_input) @@ -173,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract): if self.show_protos: self.plot_protos(ax, protos, plabels) x = np.vstack((x_train, protos)) - mesh_input, xx, yy = self.get_mesh_input(x) + mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) else: - mesh_input, xx, yy = self.get_mesh_input(x_train) + mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution) _components = pl_module.proto_layer._components mesh_input = torch.Tensor(mesh_input).type_as(_components) y_pred = pl_module.predict_latent(mesh_input, @@ -198,7 +190,7 @@ class VisCBC2D(Vis2DAbstract): self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, "w") x = np.vstack((x_train, protos)) - mesh_input, xx, yy = self.get_mesh_input(x) + mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) _components = pl_module.components_layer._components y_pred = pl_module.predict( torch.Tensor(mesh_input).type_as(_components))