From 4941c2b89dbe017174ad5170e868fbda5aa90f61 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Tue, 29 Mar 2022 11:26:22 +0200 Subject: [PATCH] feat: `data` argument optional in some visualizers --- prototorch/models/vis.py | 45 +++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index 2018ec9..66d19f4 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -14,7 +14,7 @@ from ..utils.utils import mesh2d class Vis2DAbstract(pl.Callback): def __init__(self, - data, + data=None, title="Prototype Visualization", cmap="viridis", xlabel="Data dimension 1", @@ -32,22 +32,26 @@ class Vis2DAbstract(pl.Callback): block=False): super().__init__() - if isinstance(data, Dataset): - x, y = next(iter(DataLoader(data, batch_size=len(data)))) - elif isinstance(data, torch.utils.data.DataLoader): - x = torch.tensor([]) - y = torch.tensor([]) - for x_b, y_b in data: - x = torch.cat([x, x_b]) - y = torch.cat([y, y_b]) + if data: + if isinstance(data, Dataset): + x, y = next(iter(DataLoader(data, batch_size=len(data)))) + elif isinstance(data, torch.utils.data.DataLoader): + x = torch.tensor([]) + y = torch.tensor([]) + for x_b, y_b in data: + x = torch.cat([x, x_b]) + y = torch.cat([y, y_b]) + else: + x, y = data + + if flatten_data: + x = x.reshape(len(x), -1) + + self.x_train = x + self.y_train = y else: - x, y = data - - if flatten_data: - x = x.reshape(len(x), -1) - - self.x_train = x - self.y_train = y + self.x_train = None + self.y_train = None self.title = title self.xlabel = xlabel @@ -136,10 +140,13 @@ class VisGLVQ2D(Vis2DAbstract): plabels = pl_module.prototype_labels x_train, y_train = self.x_train, self.y_train ax = self.setup_ax() - self.plot_data(ax, x_train, y_train) self.plot_protos(ax, protos, plabels) - x = np.vstack((x_train, protos)) - mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) + if x_train is not None: + self.plot_data(ax, x_train, y_train) + mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]), + self.border, self.resolution) + else: + mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution) _components = pl_module.proto_layer._components mesh_input = torch.from_numpy(mesh_input).type_as(_components) y_pred = pl_module.predict(mesh_input)