Examples use GPUs if available.

This commit is contained in:
Alexander Engelsberger
2021-05-13 15:22:01 +02:00
parent 8f9c29bd2b
commit 0eac2ce326
14 changed files with 56 additions and 39 deletions

View File

@@ -37,6 +37,7 @@ if __name__ == "__main__":
# Setup trainer
trainer = pl.Trainer(
gpus=-1,
max_epochs=200,
callbacks=[
dvis,

View File

@@ -30,10 +30,11 @@ if __name__ == "__main__":
prototype_initializer=pt.components.SMI(train_ds))
# Callbacks
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False)
# Setup trainer
trainer = pl.Trainer(
gpus=-1,
max_epochs=50,
callbacks=[vis],
)

View File

@@ -3,17 +3,7 @@
import prototorch as pt
import pytorch_lightning as pl
import torch
class StopOnNaN(pl.Callback):
def __init__(self, param):
super().__init__()
self.param = param
def on_epoch_end(self, trainer, pl_module, logs={}):
if torch.isnan(self.param).any():
raise ValueError("NaN encountered. Stopping.")
from prototorch.models.callbacks import StopOnNaN
if __name__ == "__main__":
# Dataset
@@ -40,11 +30,12 @@ if __name__ == "__main__":
noise=1e-1))
# Callbacks
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=False, block=True)
snan = StopOnNaN(model.proto_layer.components)
# Setup trainer
trainer = pl.Trainer(
gpus=-1,
max_epochs=200,
callbacks=[vis, snan],
)

View File

@@ -29,7 +29,7 @@ if __name__ == "__main__":
prototype_initializer=pt.components.SMI(train_ds))
# Setup trainer
trainer = pl.Trainer(max_epochs=100)
trainer = pl.Trainer(max_epochs=100, gpus=-1)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -53,13 +53,15 @@ if __name__ == "__main__":
# Callbacks
vis = pt.models.VisImgComp(data=train_ds,
nrow=5,
show=False,
tensorboard=True)
show=True,
tensorboard=True,
pause_time=0.5)
# Setup trainer
trainer = pl.Trainer(
max_epochs=50,
callbacks=[vis],
gpus=-1,
# overfit_batches=1,
# fast_dev_run=3,
)

View File

@@ -26,7 +26,7 @@ if __name__ == "__main__":
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
# Setup trainer
trainer = pl.Trainer(max_epochs=1, callbacks=[vis])
trainer = pl.Trainer(max_epochs=1, callbacks=[vis], gpus=-1)
# Training loop
# This is only for visualization. k-NN has no training phase.

View File

@@ -34,7 +34,7 @@ if __name__ == "__main__":
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
trainer = pl.Trainer(max_epochs=200, callbacks=[vis], gpus=-1)
# Training loop
trainer.fit(model, train_loader)

View File

@@ -34,7 +34,7 @@ if __name__ == "__main__":
vis = pt.models.VisNG2D(data=train_ds)
# Setup trainer
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
trainer = pl.Trainer(gpus=-1, max_epochs=200, callbacks=[vis])
# Training loop
trainer.fit(model, train_loader)

View File

@@ -55,7 +55,7 @@ if __name__ == "__main__":
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
# Setup trainer
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
trainer = pl.Trainer(max_epochs=100, callbacks=[vis], gpus=-1)
# Training loop
trainer.fit(model, train_loader)