Use mesh2d
from prototorch.utils
This commit is contained in:
parent
4ab0a5a414
commit
7eb496110f
@ -7,6 +7,8 @@ import torchvision
|
||||
from matplotlib import pyplot as plt
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from ..utils.utils import mesh2d
|
||||
|
||||
|
||||
class Vis2DAbstract(pl.Callback):
|
||||
def __init__(self,
|
||||
@ -73,16 +75,6 @@ class Vis2DAbstract(pl.Callback):
|
||||
ax.axis("off")
|
||||
return ax
|
||||
|
||||
def get_mesh_input(self, x):
|
||||
x_shift = self.border * np.ptp(x[:, 0])
|
||||
y_shift = self.border * np.ptp(x[:, 1])
|
||||
x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
|
||||
y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
|
||||
xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
|
||||
np.linspace(y_min, y_max, self.resolution))
|
||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
||||
return mesh_input, xx, yy
|
||||
|
||||
def plot_data(self, ax, x, y):
|
||||
ax.scatter(
|
||||
x[:, 0],
|
||||
@ -138,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract):
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||
_components = pl_module.proto_layer._components
|
||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||
y_pred = pl_module.predict(mesh_input)
|
||||
@ -173,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
if self.show_protos:
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||
else:
|
||||
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
||||
mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
|
||||
_components = pl_module.proto_layer._components
|
||||
mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
||||
y_pred = pl_module.predict_latent(mesh_input,
|
||||
@ -198,7 +190,7 @@ class VisCBC2D(Vis2DAbstract):
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
self.plot_protos(ax, protos, "w")
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||
_components = pl_module.components_layer._components
|
||||
y_pred = pl_module.predict(
|
||||
torch.Tensor(mesh_input).type_as(_components))
|
||||
|
Loading…
Reference in New Issue
Block a user