Examples use GPUs if available.

This commit is contained in:
Alexander Engelsberger 2021-05-13 15:22:01 +02:00
parent 8f9c29bd2b
commit 0eac2ce326
14 changed files with 56 additions and 39 deletions

View File

@ -37,6 +37,7 @@ if __name__ == "__main__":
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
gpus=-1,
max_epochs=200, max_epochs=200,
callbacks=[ callbacks=[
dvis, dvis,

View File

@ -30,10 +30,11 @@ if __name__ == "__main__":
prototype_initializer=pt.components.SMI(train_ds)) prototype_initializer=pt.components.SMI(train_ds))
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
gpus=-1,
max_epochs=50, max_epochs=50,
callbacks=[vis], callbacks=[vis],
) )

View File

@ -3,17 +3,7 @@
import prototorch as pt import prototorch as pt
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from prototorch.models.callbacks import StopOnNaN
class StopOnNaN(pl.Callback):
def __init__(self, param):
super().__init__()
self.param = param
def on_epoch_end(self, trainer, pl_module, logs={}):
if torch.isnan(self.param).any():
raise ValueError("NaN encountered. Stopping.")
if __name__ == "__main__": if __name__ == "__main__":
# Dataset # Dataset
@ -40,11 +30,12 @@ if __name__ == "__main__":
noise=1e-1)) noise=1e-1))
# Callbacks # Callbacks
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True) vis = pt.models.VisGLVQ2D(train_ds, show_last_only=False, block=True)
snan = StopOnNaN(model.proto_layer.components) snan = StopOnNaN(model.proto_layer.components)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
gpus=-1,
max_epochs=200, max_epochs=200,
callbacks=[vis, snan], callbacks=[vis, snan],
) )

View File

@ -29,7 +29,7 @@ if __name__ == "__main__":
prototype_initializer=pt.components.SMI(train_ds)) prototype_initializer=pt.components.SMI(train_ds))
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=100) trainer = pl.Trainer(max_epochs=100, gpus=-1)
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

View File

@ -53,13 +53,15 @@ if __name__ == "__main__":
# Callbacks # Callbacks
vis = pt.models.VisImgComp(data=train_ds, vis = pt.models.VisImgComp(data=train_ds,
nrow=5, nrow=5,
show=False, show=True,
tensorboard=True) tensorboard=True,
pause_time=0.5)
# Setup trainer # Setup trainer
trainer = pl.Trainer( trainer = pl.Trainer(
max_epochs=50, max_epochs=50,
callbacks=[vis], callbacks=[vis],
gpus=-1,
# overfit_batches=1, # overfit_batches=1,
# fast_dev_run=3, # fast_dev_run=3,
) )

View File

@ -26,7 +26,7 @@ if __name__ == "__main__":
vis = pt.models.VisGLVQ2D(data=(x_train, y_train)) vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=1, callbacks=[vis]) trainer = pl.Trainer(max_epochs=1, callbacks=[vis], gpus=-1)
# Training loop # Training loop
# This is only for visualization. k-NN has no training phase. # This is only for visualization. k-NN has no training phase.

View File

@ -34,7 +34,7 @@ if __name__ == "__main__":
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1) vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis]) trainer = pl.Trainer(max_epochs=200, callbacks=[vis], gpus=-1)
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

View File

@ -34,7 +34,7 @@ if __name__ == "__main__":
vis = pt.models.VisNG2D(data=train_ds) vis = pt.models.VisNG2D(data=train_ds)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis]) trainer = pl.Trainer(gpus=-1, max_epochs=200, callbacks=[vis])
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

View File

@ -55,7 +55,7 @@ if __name__ == "__main__":
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1) vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer # Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis]) trainer = pl.Trainer(max_epochs=100, callbacks=[vis], gpus=-1)
# Training loop # Training loop
trainer.fit(model, train_loader) trainer.fit(model, train_loader)

View File

@ -0,0 +1,14 @@
"""Callbacks for Pytorch Lighning Modules"""
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
def __init__(self, param):
super().__init__()
self.param = param
def on_epoch_end(self, trainer, pl_module, logs={}):
if torch.isnan(self.param).any():
raise ValueError("NaN encountered. Stopping.")

View File

@ -153,7 +153,7 @@ class CBC(pl.LightningModule):
with torch.no_grad(): with torch.no_grad():
y_pred = self(x) y_pred = self(x)
y_pred = torch.argmax(y_pred, dim=1) y_pred = torch.argmax(y_pred, dim=1)
return y_pred.numpy() return y_pred
class ImageCBC(CBC): class ImageCBC(CBC):

View File

@ -84,7 +84,7 @@ class GLVQ(AbstractPrototypeModel):
d = self(x) d = self(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
y_pred = wtac(d, plabels) y_pred = wtac(d, plabels)
return y_pred.numpy() return y_pred
class LVQ1(GLVQ): class LVQ1(GLVQ):

View File

@ -33,7 +33,7 @@ class KNN(AbstractPrototypeModel):
@property @property
def prototype_labels(self): def prototype_labels(self):
return self.proto_layer.component_labels.detach().cpu() return self.proto_layer.component_labels.detach()
def forward(self, x): def forward(self, x):
protos, _ = self.proto_layer() protos, _ = self.proto_layer()
@ -46,7 +46,7 @@ class KNN(AbstractPrototypeModel):
d = self(x) d = self(x)
plabels = self.proto_layer.component_labels plabels = self.proto_layer.component_labels
y_pred = knnc(d, plabels, k=self.hparams.k) y_pred = knnc(d, plabels, k=self.hparams.k)
return y_pred.numpy() return y_pred
def training_step(self, train_batch, batch_idx, optimizer_idx=None): def training_step(self, train_batch, batch_idx, optimizer_idx=None):
return 1 return 1

View File

@ -379,8 +379,10 @@ class VisGLVQ2D(Vis2DAbstract):
self.plot_protos(ax, protos, plabels) self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input)) _components = pl_module.proto_layer._components
y_pred = y_pred.reshape(xx.shape) y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
@ -398,20 +400,24 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
protos = pl_module.prototypes protos = pl_module.prototypes
plabels = pl_module.prototype_labels plabels = pl_module.prototype_labels
x_train, y_train = self.x_train, self.y_train x_train, y_train = self.x_train, self.y_train
x_train = pl_module.backbone(torch.Tensor(x_train)).detach() x_train = pl_module.backbone(
torch.Tensor(x_train).to(pl_module.device)).cpu().detach()
if self.map_protos: if self.map_protos:
protos = pl_module.backbone(torch.Tensor(protos)).detach() protos = pl_module.backbone(
torch.Tensor(protos).to(pl_module.device)).cpu().detach()
ax = self.setup_ax() ax = self.setup_ax()
self.plot_data(ax, x_train, y_train) self.plot_data(ax, x_train, y_train)
if self.show_protos: #if self.show_protos:
self.plot_protos(ax, protos, plabels) # self.plot_protos(ax, protos, plabels)
x = np.vstack((x_train, protos)) # x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) # mesh_input, xx, yy = self.get_mesh_input(x)
else: #else:
mesh_input, xx, yy = self.get_mesh_input(x_train) # mesh_input, xx, yy = self.get_mesh_input(x_train)
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input)) #_components = pl_module.proto_layer._components
y_pred = y_pred.reshape(xx.shape) #y_pred = pl_module.predict(
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) # torch.Tensor(mesh_input).type_as(_components))
#y_pred = y_pred.cpu().reshape(xx.shape)
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
self.log_and_display(trainer, pl_module) self.log_and_display(trainer, pl_module)
@ -429,8 +435,10 @@ class VisCBC2D(Vis2DAbstract):
self.plot_protos(ax, protos, "w") self.plot_protos(ax, protos, "w")
x = np.vstack((x_train, protos)) x = np.vstack((x_train, protos))
mesh_input, xx, yy = self.get_mesh_input(x) mesh_input, xx, yy = self.get_mesh_input(x)
y_pred = pl_module.predict(torch.Tensor(mesh_input)) _components = pl_module.component_layer._components
y_pred = y_pred.reshape(xx.shape) y_pred = pl_module.predict(
torch.Tensor(mesh_input).type_as(_components))
y_pred = y_pred.cpu().reshape(xx.shape)
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35) ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)