Examples use GPUs if available.

This commit is contained in:
Alexander Engelsberger
2021-05-13 15:22:01 +02:00
parent 8f9c29bd2b
commit 0eac2ce326
14 changed files with 56 additions and 39 deletions

View File

@@ -379,8 +379,10 @@ class VisGLVQ2D(Vis2DAbstract):
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
_components = pl_module.proto_layer._components
y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
@@ -398,20 +400,24 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
protos = pl_module.prototypes
plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
x_train = pl_module.backbone(
torch.Tensor(x_train).to(pl_module.device)).cpu().detach()
if self.map_protos:
protos = pl_module.backbone(torch.Tensor(protos)).detach()
protos = pl_module.backbone(
torch.Tensor(protos).to(pl_module.device)).cpu().detach()
ax = self.setup_ax()
self.plot_data(ax, x_train, y_train)
if self.show_protos:
self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
else:
mesh_input, xx, yy = self.get_mesh_input(x_train)
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
#if self.show_protos:
# self.plot_protos(ax, protos, plabels)
# x = np.vstack((x_train, protos))
# mesh_input, xx, yy = self.get_mesh_input(x)
#else:
# mesh_input, xx, yy = self.get_mesh_input(x_train)
#_components = pl_module.proto_layer._components
#y_pred = pl_module.predict(
# torch.Tensor(mesh_input).type_as(_components))
#y_pred = y_pred.cpu().reshape(xx.shape)
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module)
@@ -429,8 +435,10 @@ class VisCBC2D(Vis2DAbstract):
self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input))
y_pred = y_pred.reshape(xx.shape)
_components = pl_module.component_layer._components
y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)