feat: data argument optional in some visualizers

This commit is contained in:
Jensun Ravichandran 2022-03-29 11:26:22 +02:00
parent ce14dec7e9
commit 4941c2b89d
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -14,7 +14,7 @@ from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback): class Vis2DAbstract(pl.Callback):
def __init__(self, def __init__(self,
data, data=None,
title="Prototype Visualization", title="Prototype Visualization",
cmap="viridis", cmap="viridis",
xlabel="Data dimension 1", xlabel="Data dimension 1",
@ -32,22 +32,26 @@ class Vis2DAbstract(pl.Callback):
block=False): block=False):
super().__init__() super().__init__()
if isinstance(data, Dataset): if data:
x, y = next(iter(DataLoader(data, batch_size=len(data)))) if isinstance(data, Dataset):
elif isinstance(data, torch.utils.data.DataLoader): x, y = next(iter(DataLoader(data, batch_size=len(data))))
x = torch.tensor([]) elif isinstance(data, torch.utils.data.DataLoader):
y = torch.tensor([]) x = torch.tensor([])
for x_b, y_b in data: y = torch.tensor([])
x = torch.cat([x, x_b]) for x_b, y_b in data:
y = torch.cat([y, y_b]) x = torch.cat([x, x_b])
y = torch.cat([y, y_b])
else:
x, y = data
if flatten_data:
x = x.reshape(len(x), -1)
self.x_train = x
self.y_train = y
else: else:
x, y = data self.x_train = None
self.y_train = None
if flatten_data:
x = x.reshape(len(x), -1)
self.x_train = x
self.y_train = y
self.title = title self.title = title
self.xlabel = xlabel self.xlabel = xlabel
@ -136,10 +140,13 @@ class VisGLVQ2D(Vis2DAbstract):
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax() ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) if x_train is not None:
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components) mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input) y_pred = pl_module.predict(mesh_input)