Examples use GPUs if available.
This commit is contained in:
@@ -37,6 +37,7 @@ if __name__ == "__main__":
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
gpus=-1,
|
||||
max_epochs=200,
|
||||
callbacks=[
|
||||
dvis,
|
||||
|
@@ -30,10 +30,11 @@ if __name__ == "__main__":
|
||||
prototype_initializer=pt.components.SMI(train_ds))
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
gpus=-1,
|
||||
max_epochs=50,
|
||||
callbacks=[vis],
|
||||
)
|
||||
|
@@ -3,17 +3,7 @@
|
||||
import prototorch as pt
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
|
||||
class StopOnNaN(pl.Callback):
|
||||
def __init__(self, param):
|
||||
super().__init__()
|
||||
self.param = param
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module, logs={}):
|
||||
if torch.isnan(self.param).any():
|
||||
raise ValueError("NaN encountered. Stopping.")
|
||||
|
||||
from prototorch.models.callbacks import StopOnNaN
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Dataset
|
||||
@@ -40,11 +30,12 @@ if __name__ == "__main__":
|
||||
noise=1e-1))
|
||||
|
||||
# Callbacks
|
||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=False, block=True)
|
||||
snan = StopOnNaN(model.proto_layer.components)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
gpus=-1,
|
||||
max_epochs=200,
|
||||
callbacks=[vis, snan],
|
||||
)
|
||||
|
@@ -29,7 +29,7 @@ if __name__ == "__main__":
|
||||
prototype_initializer=pt.components.SMI(train_ds))
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=100)
|
||||
trainer = pl.Trainer(max_epochs=100, gpus=-1)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
@@ -53,13 +53,15 @@ if __name__ == "__main__":
|
||||
# Callbacks
|
||||
vis = pt.models.VisImgComp(data=train_ds,
|
||||
nrow=5,
|
||||
show=False,
|
||||
tensorboard=True)
|
||||
show=True,
|
||||
tensorboard=True,
|
||||
pause_time=0.5)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=50,
|
||||
callbacks=[vis],
|
||||
gpus=-1,
|
||||
# overfit_batches=1,
|
||||
# fast_dev_run=3,
|
||||
)
|
||||
|
@@ -26,7 +26,7 @@ if __name__ == "__main__":
|
||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=1, callbacks=[vis])
|
||||
trainer = pl.Trainer(max_epochs=1, callbacks=[vis], gpus=-1)
|
||||
|
||||
# Training loop
|
||||
# This is only for visualization. k-NN has no training phase.
|
||||
|
@@ -34,7 +34,7 @@ if __name__ == "__main__":
|
||||
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
|
||||
trainer = pl.Trainer(max_epochs=200, callbacks=[vis], gpus=-1)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
@@ -34,7 +34,7 @@ if __name__ == "__main__":
|
||||
vis = pt.models.VisNG2D(data=train_ds)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
|
||||
trainer = pl.Trainer(gpus=-1, max_epochs=200, callbacks=[vis])
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
@@ -55,7 +55,7 @@ if __name__ == "__main__":
|
||||
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||
|
||||
# Setup trainer
|
||||
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
|
||||
trainer = pl.Trainer(max_epochs=100, callbacks=[vis], gpus=-1)
|
||||
|
||||
# Training loop
|
||||
trainer.fit(model, train_loader)
|
||||
|
Reference in New Issue
Block a user