feat: data
argument optional in some visualizers
This commit is contained in:
parent
ce14dec7e9
commit
4941c2b89d
@ -14,7 +14,7 @@ from ..utils.utils import mesh2d
|
|||||||
class Vis2DAbstract(pl.Callback):
|
class Vis2DAbstract(pl.Callback):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data,
|
data=None,
|
||||||
title="Prototype Visualization",
|
title="Prototype Visualization",
|
||||||
cmap="viridis",
|
cmap="viridis",
|
||||||
xlabel="Data dimension 1",
|
xlabel="Data dimension 1",
|
||||||
@ -32,22 +32,26 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
block=False):
|
block=False):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if isinstance(data, Dataset):
|
if data:
|
||||||
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
if isinstance(data, Dataset):
|
||||||
elif isinstance(data, torch.utils.data.DataLoader):
|
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
||||||
x = torch.tensor([])
|
elif isinstance(data, torch.utils.data.DataLoader):
|
||||||
y = torch.tensor([])
|
x = torch.tensor([])
|
||||||
for x_b, y_b in data:
|
y = torch.tensor([])
|
||||||
x = torch.cat([x, x_b])
|
for x_b, y_b in data:
|
||||||
y = torch.cat([y, y_b])
|
x = torch.cat([x, x_b])
|
||||||
|
y = torch.cat([y, y_b])
|
||||||
|
else:
|
||||||
|
x, y = data
|
||||||
|
|
||||||
|
if flatten_data:
|
||||||
|
x = x.reshape(len(x), -1)
|
||||||
|
|
||||||
|
self.x_train = x
|
||||||
|
self.y_train = y
|
||||||
else:
|
else:
|
||||||
x, y = data
|
self.x_train = None
|
||||||
|
self.y_train = None
|
||||||
if flatten_data:
|
|
||||||
x = x.reshape(len(x), -1)
|
|
||||||
|
|
||||||
self.x_train = x
|
|
||||||
self.y_train = y
|
|
||||||
|
|
||||||
self.title = title
|
self.title = title
|
||||||
self.xlabel = xlabel
|
self.xlabel = xlabel
|
||||||
@ -136,10 +140,13 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax()
|
||||||
self.plot_data(ax, x_train, y_train)
|
|
||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
if x_train is not None:
|
||||||
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
|
self.plot_data(ax, x_train, y_train)
|
||||||
|
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
|
||||||
|
self.border, self.resolution)
|
||||||
|
else:
|
||||||
|
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
|
||||||
_components = pl_module.proto_layer._components
|
_components = pl_module.proto_layer._components
|
||||||
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
|
||||||
y_pred = pl_module.predict(mesh_input)
|
y_pred = pl_module.predict(mesh_input)
|
||||||
|
Loading…
Reference in New Issue
Block a user