feat: data argument optional in some visualizers

This commit is contained in:
Jensun Ravichandran 2022-03-29 11:26:22 +02:00
parent ce14dec7e9
commit 4941c2b89d
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -14,7 +14,7 @@ from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback):
def __init__(self,
data,
data=None,
title="Prototype Visualization",
cmap="viridis",
xlabel="Data dimension 1",
@ -32,22 +32,26 @@ class Vis2DAbstract(pl.Callback):
block=False):
super().__init__()
if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
x = torch.cat([x, x_b])
y = torch.cat([y, y_b])
if data:
if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
x = torch.cat([x, x_b])
y = torch.cat([y, y_b])
else:
x, y = data
if flatten_data:
x = x.reshape(len(x), -1)
self.x_train = x
self.y_train = y
else:
x, y = data
if flatten_data:
x = x.reshape(len(x), -1)
self.x_train = x
self.y_train = y
self.x_train = None
self.y_train = None
self.title = title
self.xlabel = xlabel
@ -136,10 +140,13 @@ class VisGLVQ2D(Vis2DAbstract):
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution)
if x_train is not None:
self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input)