Use mesh2d from prototorch.utils

This commit is contained in:
Jensun Ravichandran 2021-06-18 13:43:44 +02:00
parent 4ab0a5a414
commit 7eb496110f

View File

@ -7,6 +7,8 @@ import torchvision
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback): class Vis2DAbstract(pl.Callback):
def __init__(self, def __init__(self,
@ -73,16 +75,6 @@ class Vis2DAbstract(pl.Callback):
ax.axis("off") ax.axis("off")
return ax return ax
def get_mesh_input(self, x):
x_shift = self.border * np.ptp(x[:, 0])
y_shift = self.border * np.ptp(x[:, 1])
x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
np.linspace(y_min, y_max, self.resolution))
mesh_input = np.c_[xx.ravel(), yy.ravel()]
return mesh_input, xx, yy
def plot_data(self, ax, x, y): def plot_data(self, ax, x, y):
ax.scatter( ax.scatter(
x[:, 0], x[:, 0],
@ -138,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract):
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components) mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input) y_pred = pl_module.predict(mesh_input)
@ -173,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
if self.show_protos: if self.show_protos:
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
else: else:
mesh_input, xx, yy = self.get_mesh_input(x_train) mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
mesh_input = torch.Tensor(mesh_input).type_as(_components) mesh_input = torch.Tensor(mesh_input).type_as(_components)
y_pred = pl_module.predict_latent(mesh_input, y_pred = pl_module.predict_latent(mesh_input,
@ -198,7 +190,7 @@ class VisCBC2D(Vis2DAbstract):
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, "w") self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
_components = pl_module.components_layer._components _components = pl_module.components_layer._components
y_pred = pl_module.predict( y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components)) torch.Tensor(mesh_input).type_as(_components))