Use mesh2d
from prototorch.utils
This commit is contained in:
parent
4ab0a5a414
commit
7eb496110f
@ -7,6 +7,8 @@ import torchvision
|
|||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from ..utils.utils import mesh2d
|
||||||
|
|
||||||
|
|
||||||
class Vis2DAbstract(pl.Callback):
|
class Vis2DAbstract(pl.Callback):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -73,16 +75,6 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
ax.axis("off")
|
ax.axis("off")
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
def get_mesh_input(self, x):
|
|
||||||
x_shift = self.border * np.ptp(x[:, 0])
|
|
||||||
y_shift = self.border * np.ptp(x[:, 1])
|
|
||||||
x_min, x_max = x[:, 0].min() - x_shift, x[:, 0].max() + x_shift
|
|
||||||
y_min, y_max = x[:, 1].min() - y_shift, x[:, 1].max() + y_shift
|
|
||||||
xx, yy = np.meshgrid(np.linspace(x_min, x_max, self.resolution),
|
|
||||||
np.linspace(y_min, y_max, self.resolution))
|
|
||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
|
||||||
return mesh_input, xx, yy
|
|
||||||
|
|
||||||
def plot_data(self, ax, x, y):
|
def plot_data(self, ax, x, y):
|
||||||
ax.scatter(
|
ax.scatter(
|
||||||
x[:, 0],
|
x[:, 0],
|
||||||
@ -138,7 +130,7 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||||
y_pred = pl_module.predict(mesh_input)
|
y_pred = pl_module.predict(mesh_input)
|
||||||
@ -173,9 +165,9 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
if self.show_protos:
|
if self.show_protos:
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||||
else:
|
else:
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
mesh_input, xx, yy = mesh2d(x_train, self.border, self.resolution)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
mesh_input = torch.Tensor(mesh_input).type_as(_components)
|
||||||
y_pred = pl_module.predict_latent(mesh_input,
|
y_pred = pl_module.predict_latent(mesh_input,
|
||||||
@ -198,7 +190,7 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
||||||
_components = pl_module.components_layer._components
|
_components = pl_module.components_layer._components
|
||||||
y_pred = pl_module.predict(
|
y_pred = pl_module.predict(
|
||||||
torch.Tensor(mesh_input).type_as(_components))
|
torch.Tensor(mesh_input).type_as(_components))
|
||||||
|
Loading…
Reference in New Issue
Block a user