Examples use GPUs if available.
This commit is contained in:
14
prototorch/models/callbacks.py
Normal file
14
prototorch/models/callbacks.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Callbacks for Pytorch Lighning Modules"""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class StopOnNaN(pl.Callback):
|
||||
def __init__(self, param):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module, logs={}):
|
||||
if torch.isnan(self.param).any():
|
||||
raise ValueError("NaN encountered. Stopping.")
|
@@ -153,7 +153,7 @@ class CBC(pl.LightningModule):
|
||||
with torch.no_grad():
|
||||
y_pred = self(x)
|
||||
y_pred = torch.argmax(y_pred, dim=1)
|
||||
return y_pred.numpy()
|
||||
return y_pred
|
||||
|
||||
|
||||
class ImageCBC(CBC):
|
||||
|
@@ -84,7 +84,7 @@ class GLVQ(AbstractPrototypeModel):
|
||||
d = self(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
y_pred = wtac(d, plabels)
|
||||
return y_pred.numpy()
|
||||
return y_pred
|
||||
|
||||
|
||||
class LVQ1(GLVQ):
|
||||
|
@@ -33,7 +33,7 @@ class KNN(AbstractPrototypeModel):
|
||||
|
||||
@property
|
||||
def prototype_labels(self):
|
||||
return self.proto_layer.component_labels.detach().cpu()
|
||||
return self.proto_layer.component_labels.detach()
|
||||
|
||||
def forward(self, x):
|
||||
protos, _ = self.proto_layer()
|
||||
@@ -46,7 +46,7 @@ class KNN(AbstractPrototypeModel):
|
||||
d = self(x)
|
||||
plabels = self.proto_layer.component_labels
|
||||
y_pred = knnc(d, plabels, k=self.hparams.k)
|
||||
return y_pred.numpy()
|
||||
return y_pred
|
||||
|
||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||
return 1
|
||||
|
@@ -379,8 +379,10 @@ class VisGLVQ2D(Vis2DAbstract):
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
_components = pl_module.proto_layer._components
|
||||
y_pred = pl_module.predict(
|
||||
torch.Tensor(mesh_input).type_as(_components))
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
@@ -398,20 +400,24 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
||||
protos = pl_module.prototypes
|
||||
plabels = pl_module.prototype_labels
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
|
||||
x_train = pl_module.backbone(
|
||||
torch.Tensor(x_train).to(pl_module.device)).cpu().detach()
|
||||
if self.map_protos:
|
||||
protos = pl_module.backbone(torch.Tensor(protos)).detach()
|
||||
protos = pl_module.backbone(
|
||||
torch.Tensor(protos).to(pl_module.device)).cpu().detach()
|
||||
ax = self.setup_ax()
|
||||
self.plot_data(ax, x_train, y_train)
|
||||
if self.show_protos:
|
||||
self.plot_protos(ax, protos, plabels)
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
else:
|
||||
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
||||
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
#if self.show_protos:
|
||||
# self.plot_protos(ax, protos, plabels)
|
||||
# x = np.vstack((x_train, protos))
|
||||
# mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
#else:
|
||||
# mesh_input, xx, yy = self.get_mesh_input(x_train)
|
||||
#_components = pl_module.proto_layer._components
|
||||
#y_pred = pl_module.predict(
|
||||
# torch.Tensor(mesh_input).type_as(_components))
|
||||
#y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
@@ -429,8 +435,10 @@ class VisCBC2D(Vis2DAbstract):
|
||||
self.plot_protos(ax, protos, "w")
|
||||
x = np.vstack((x_train, protos))
|
||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
||||
y_pred = y_pred.reshape(xx.shape)
|
||||
_components = pl_module.component_layer._components
|
||||
y_pred = pl_module.predict(
|
||||
torch.Tensor(mesh_input).type_as(_components))
|
||||
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||
|
||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||
|
||||
|
Reference in New Issue
Block a user