feat: data argument optional in some visualizers

This commit is contained in:
Jensun Ravichandran 2022-03-29 11:26:22 +02:00
parent ce14dec7e9
commit 4941c2b89d
No known key found for this signature in database
GPG Key ID: 4E9348239810B51F

View File

@ -14,7 +14,7 @@ from ..utils.utils import mesh2d
class Vis2DAbstract(pl.Callback): class Vis2DAbstract(pl.Callback):
def __init__(self, def __init__(self,
data, data=None,
title="Prototype Visualization", title="Prototype Visualization",
cmap="viridis", cmap="viridis",
xlabel="Data dimension 1", xlabel="Data dimension 1",
@ -32,6 +32,7 @@ class Vis2DAbstract(pl.Callback):
block=False): block=False):
super().__init__() super().__init__()
if data:
if isinstance(data, Dataset): if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data)))) x, y = next(iter(DataLoader(data, batch_size=len(data))))
elif isinstance(data, torch.utils.data.DataLoader): elif isinstance(data, torch.utils.data.DataLoader):
@ -48,6 +49,9 @@ class Vis2DAbstract(pl.Callback):
self.x_train = x self.x_train = x
self.y_train = y self.y_train = y
else:
self.x_train = None
self.y_train = None
self.title = title self.title = title
self.xlabel = xlabel self.xlabel = xlabel
@ -136,10 +140,13 @@ class VisGLVQ2D(Vis2DAbstract):
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax() ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) if x_train is not None:
mesh_input, xx, yy = mesh2d(x, self.border, self.resolution) self.plot_data(ax, x_train, y_train)
mesh_input, xx, yy = mesh2d(np.vstack([x_train, protos]),
self.border, self.resolution)
else:
mesh_input, xx, yy = mesh2d(protos, self.border, self.resolution)
_components = pl_module.proto_layer._components _components = pl_module.proto_layer._components
mesh_input = torch.from_numpy(mesh_input).type_as(_components) mesh_input = torch.from_numpy(mesh_input).type_as(_components)
y_pred = pl_module.predict(mesh_input) y_pred = pl_module.predict(mesh_input)