Refactor visualization callbacks

This commit is contained in:
Jensun Ravichandran 2021-05-09 20:53:03 +02:00
parent dd75fbfff8
commit ff7a1e93d2

View File

@ -269,6 +269,7 @@ class Vis2DAbstract(pl.Callback):
cmap="viridis", cmap="viridis",
border=1, border=1,
resolution=50, resolution=50,
show_protos=True,
tensorboard=False, tensorboard=False,
show_last_only=False, show_last_only=False,
pause_time=0.1, pause_time=0.1,
@ -288,11 +289,17 @@ class Vis2DAbstract(pl.Callback):
self.cmap = cmap self.cmap = cmap
self.border = border self.border = border
self.resolution = resolution self.resolution = resolution
self.show_protos = show_protos
self.tensorboard = tensorboard self.tensorboard = tensorboard
self.show_last_only = show_last_only self.show_last_only = show_last_only
self.pause_time = pause_time self.pause_time = pause_time
self.block = block self.block = block
def precheck(self, trainer):
if self.show_last_only:
if trainer.current_epoch != trainer.max_epochs - 1:
return
def setup_ax(self, xlabel=None, ylabel=None): def setup_ax(self, xlabel=None, ylabel=None):
ax = self.fig.gca() ax = self.fig.gca()
ax.cla() ax.cla()
@ -312,6 +319,28 @@ class Vis2DAbstract(pl.Callback):
mesh_input = np.c_[xx.ravel(), yy.ravel()] mesh_input = np.c_[xx.ravel(), yy.ravel()]
return mesh_input, xx, yy return mesh_input, xx, yy
def plot_data(self, ax, x, y):
ax.scatter(
x[:, 0],
x[:, 1],
c=y,
cmap=self.cmap,
edgecolor="k",
marker="o",
s=30,
)
def plot_protos(self, ax, protos, plabels):
ax.scatter(
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
def add_to_tensorboard(self, trainer, pl_module): def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment tb = pl_module.logger.experiment
tb.add_figure(tag=f"{self.title}", tb.add_figure(tag=f"{self.title}",
@ -330,115 +359,89 @@ class Vis2DAbstract(pl.Callback):
class VisGLVQ2D(Vis2DAbstract): class VisGLVQ2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
if self.show_last_only: self.precheck(trainer)
if trainer.current_epoch != trainer.max_epochs - 1:
return
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2") ylabel="Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") self.plot_data(ax, x_train, y_train)
ax.scatter( self.plot_protos(ax, protos, plabels)
protos[:, 0],
protos[:, 1],
c=plabels,
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape) y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# ax.set_xlim(left=x_min + 0, right=x_max - 0)
# ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
class VisSiameseGLVQ2D(Vis2DAbstract): class VisSiameseGLVQ2D(Vis2DAbstract):
def __init__(self, *args, map_protos=True, **kwargs):
super().__init__(*args, **kwargs)
self.map_protos = map_protos
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer)
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach() x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
protos = pl_module.backbone(torch.Tensor(protos)).detach() if self.map_protos:
protos = pl_module.backbone(torch.Tensor(protos)).detach()
ax = self.setup_ax() ax = self.setup_ax()
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") self.plot_data(ax, x_train, y_train)
ax.scatter( if self.show_protos:
protos[:, 0], self.plot_protos(ax, protos, plabels)
protos[:, 1], x = np.vstack((x_train, protos))
c=plabels, mesh_input, xx, yy = self.get_mesh_input(x)
cmap=self.cmap, else:
edgecolor="k", mesh_input, xx, yy = self.get_mesh_input(x_train)
marker="D",
s=50,
)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input)) y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape) y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# ax.set_xlim(left=x_min + 0, right=x_max - 0)
# ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
class VisCBC2D(Vis2DAbstract): class VisCBC2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer)
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.components protos = pl_module.components
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2") ylabel="Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") self.plot_data(ax, x_train, y_train)
ax.scatter( self.plot_protos(ax, protos, plabels)
protos[:, 0],
protos[:, 1],
c="w",
cmap=self.cmap,
edgecolor="k",
marker="D",
s=50,
)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input)) y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape) y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
# ax.set_xlim(left=x_min + 0, right=x_max - 0)
# ax.set_ylim(bottom=y_min + 0, top=y_max - 0)
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
class VisNG2D(Vis2DAbstract): class VisNG2D(Vis2DAbstract):
def on_epoch_end(self, trainer, pl_module): def on_epoch_end(self, trainer, pl_module):
self.precheck(trainer)
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
protos = pl_module.prototypes protos = pl_module.prototypes
cmat = pl_module.topology_layer.cmat.cpu().numpy() cmat = pl_module.topology_layer.cmat.cpu().numpy()
ax = self.setup_ax(xlabel="Data dimension 1", ax = self.setup_ax(xlabel="Data dimension 1",
ylabel="Data dimension 2") ylabel="Data dimension 2")
ax.scatter(x_train[:, 0], x_train[:, 1], c=y_train, edgecolor="k") self.plot_data(ax, x_train, y_train)
ax.scatter( self.plot_protos(ax, protos, "w")
protos[:, 0],
protos[:, 1],
c="k",
edgecolor="k",
marker="D",
s=50,
)
# Draw connections # Draw connections
for i in range(len(protos)): for i in range(len(protos)):
for j in range(len(protos)): for j in range(i, len(protos)):
if cmat[i][j]: if cmat[i][j]:
ax.plot( ax.plot(
[protos[i, 0], protos[j, 0]], [protos[i, 0], protos[j, 0]],