Examples use GPUs if available.
This commit is contained in:
parent
8f9c29bd2b
commit
0eac2ce326
@ -37,6 +37,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
|
gpus=-1,
|
||||||
max_epochs=200,
|
max_epochs=200,
|
||||||
callbacks=[
|
callbacks=[
|
||||||
dvis,
|
dvis,
|
||||||
|
@ -30,10 +30,11 @@ if __name__ == "__main__":
|
|||||||
prototype_initializer=pt.components.SMI(train_ds))
|
prototype_initializer=pt.components.SMI(train_ds))
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
vis = pt.models.VisGLVQ2D(data=(x_train, y_train), block=False)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
|
gpus=-1,
|
||||||
max_epochs=50,
|
max_epochs=50,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
)
|
)
|
||||||
|
@ -3,17 +3,7 @@
|
|||||||
import prototorch as pt
|
import prototorch as pt
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
from prototorch.models.callbacks import StopOnNaN
|
||||||
|
|
||||||
class StopOnNaN(pl.Callback):
|
|
||||||
def __init__(self, param):
|
|
||||||
super().__init__()
|
|
||||||
self.param = param
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module, logs={}):
|
|
||||||
if torch.isnan(self.param).any():
|
|
||||||
raise ValueError("NaN encountered. Stopping.")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# Dataset
|
# Dataset
|
||||||
@ -40,11 +30,12 @@ if __name__ == "__main__":
|
|||||||
noise=1e-1))
|
noise=1e-1))
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=True, block=True)
|
vis = pt.models.VisGLVQ2D(train_ds, show_last_only=False, block=True)
|
||||||
snan = StopOnNaN(model.proto_layer.components)
|
snan = StopOnNaN(model.proto_layer.components)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
|
gpus=-1,
|
||||||
max_epochs=200,
|
max_epochs=200,
|
||||||
callbacks=[vis, snan],
|
callbacks=[vis, snan],
|
||||||
)
|
)
|
||||||
|
@ -29,7 +29,7 @@ if __name__ == "__main__":
|
|||||||
prototype_initializer=pt.components.SMI(train_ds))
|
prototype_initializer=pt.components.SMI(train_ds))
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=100)
|
trainer = pl.Trainer(max_epochs=100, gpus=-1)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
@ -53,13 +53,15 @@ if __name__ == "__main__":
|
|||||||
# Callbacks
|
# Callbacks
|
||||||
vis = pt.models.VisImgComp(data=train_ds,
|
vis = pt.models.VisImgComp(data=train_ds,
|
||||||
nrow=5,
|
nrow=5,
|
||||||
show=False,
|
show=True,
|
||||||
tensorboard=True)
|
tensorboard=True,
|
||||||
|
pause_time=0.5)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(
|
trainer = pl.Trainer(
|
||||||
max_epochs=50,
|
max_epochs=50,
|
||||||
callbacks=[vis],
|
callbacks=[vis],
|
||||||
|
gpus=-1,
|
||||||
# overfit_batches=1,
|
# overfit_batches=1,
|
||||||
# fast_dev_run=3,
|
# fast_dev_run=3,
|
||||||
)
|
)
|
||||||
|
@ -26,7 +26,7 @@ if __name__ == "__main__":
|
|||||||
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
vis = pt.models.VisGLVQ2D(data=(x_train, y_train))
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=1, callbacks=[vis])
|
trainer = pl.Trainer(max_epochs=1, callbacks=[vis], gpus=-1)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
# This is only for visualization. k-NN has no training phase.
|
# This is only for visualization. k-NN has no training phase.
|
||||||
|
@ -34,7 +34,7 @@ if __name__ == "__main__":
|
|||||||
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
|
vis = pt.models.VisSiameseGLVQ2D(train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
|
trainer = pl.Trainer(max_epochs=200, callbacks=[vis], gpus=-1)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
@ -34,7 +34,7 @@ if __name__ == "__main__":
|
|||||||
vis = pt.models.VisNG2D(data=train_ds)
|
vis = pt.models.VisNG2D(data=train_ds)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=200, callbacks=[vis])
|
trainer = pl.Trainer(gpus=-1, max_epochs=200, callbacks=[vis])
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
@ -55,7 +55,7 @@ if __name__ == "__main__":
|
|||||||
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
vis = pt.models.VisSiameseGLVQ2D(data=train_ds, border=0.1)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
trainer = pl.Trainer(max_epochs=100, callbacks=[vis])
|
trainer = pl.Trainer(max_epochs=100, callbacks=[vis], gpus=-1)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
trainer.fit(model, train_loader)
|
trainer.fit(model, train_loader)
|
||||||
|
14
prototorch/models/callbacks.py
Normal file
14
prototorch/models/callbacks.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
"""Callbacks for Pytorch Lighning Modules"""
|
||||||
|
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class StopOnNaN(pl.Callback):
|
||||||
|
def __init__(self, param):
|
||||||
|
super().__init__()
|
||||||
|
self.param = param
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer, pl_module, logs={}):
|
||||||
|
if torch.isnan(self.param).any():
|
||||||
|
raise ValueError("NaN encountered. Stopping.")
|
@ -153,7 +153,7 @@ class CBC(pl.LightningModule):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
y_pred = self(x)
|
y_pred = self(x)
|
||||||
y_pred = torch.argmax(y_pred, dim=1)
|
y_pred = torch.argmax(y_pred, dim=1)
|
||||||
return y_pred.numpy()
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class ImageCBC(CBC):
|
class ImageCBC(CBC):
|
||||||
|
@ -84,7 +84,7 @@ class GLVQ(AbstractPrototypeModel):
|
|||||||
d = self(x)
|
d = self(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
y_pred = wtac(d, plabels)
|
y_pred = wtac(d, plabels)
|
||||||
return y_pred.numpy()
|
return y_pred
|
||||||
|
|
||||||
|
|
||||||
class LVQ1(GLVQ):
|
class LVQ1(GLVQ):
|
||||||
|
@ -33,7 +33,7 @@ class KNN(AbstractPrototypeModel):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def prototype_labels(self):
|
def prototype_labels(self):
|
||||||
return self.proto_layer.component_labels.detach().cpu()
|
return self.proto_layer.component_labels.detach()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
protos, _ = self.proto_layer()
|
protos, _ = self.proto_layer()
|
||||||
@ -46,7 +46,7 @@ class KNN(AbstractPrototypeModel):
|
|||||||
d = self(x)
|
d = self(x)
|
||||||
plabels = self.proto_layer.component_labels
|
plabels = self.proto_layer.component_labels
|
||||||
y_pred = knnc(d, plabels, k=self.hparams.k)
|
y_pred = knnc(d, plabels, k=self.hparams.k)
|
||||||
return y_pred.numpy()
|
return y_pred
|
||||||
|
|
||||||
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
def training_step(self, train_batch, batch_idx, optimizer_idx=None):
|
||||||
return 1
|
return 1
|
||||||
|
@ -379,8 +379,10 @@ class VisGLVQ2D(Vis2DAbstract):
|
|||||||
self.plot_protos(ax, protos, plabels)
|
self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
_components = pl_module.proto_layer._components
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
y_pred = pl_module.predict(
|
||||||
|
torch.Tensor(mesh_input).type_as(_components))
|
||||||
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
self.log_and_display(trainer, pl_module)
|
||||||
@ -398,20 +400,24 @@ class VisSiameseGLVQ2D(Vis2DAbstract):
|
|||||||
protos = pl_module.prototypes
|
protos = pl_module.prototypes
|
||||||
plabels = pl_module.prototype_labels
|
plabels = pl_module.prototype_labels
|
||||||
x_train, y_train = self.x_train, self.y_train
|
x_train, y_train = self.x_train, self.y_train
|
||||||
x_train = pl_module.backbone(torch.Tensor(x_train)).detach()
|
x_train = pl_module.backbone(
|
||||||
|
torch.Tensor(x_train).to(pl_module.device)).cpu().detach()
|
||||||
if self.map_protos:
|
if self.map_protos:
|
||||||
protos = pl_module.backbone(torch.Tensor(protos)).detach()
|
protos = pl_module.backbone(
|
||||||
|
torch.Tensor(protos).to(pl_module.device)).cpu().detach()
|
||||||
ax = self.setup_ax()
|
ax = self.setup_ax()
|
||||||
self.plot_data(ax, x_train, y_train)
|
self.plot_data(ax, x_train, y_train)
|
||||||
if self.show_protos:
|
#if self.show_protos:
|
||||||
self.plot_protos(ax, protos, plabels)
|
# self.plot_protos(ax, protos, plabels)
|
||||||
x = np.vstack((x_train, protos))
|
# x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
# mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
else:
|
#else:
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x_train)
|
# mesh_input, xx, yy = self.get_mesh_input(x_train)
|
||||||
y_pred = pl_module.predict_latent(torch.Tensor(mesh_input))
|
#_components = pl_module.proto_layer._components
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
#y_pred = pl_module.predict(
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
# torch.Tensor(mesh_input).type_as(_components))
|
||||||
|
#y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
|
#ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
self.log_and_display(trainer, pl_module)
|
self.log_and_display(trainer, pl_module)
|
||||||
|
|
||||||
@ -429,8 +435,10 @@ class VisCBC2D(Vis2DAbstract):
|
|||||||
self.plot_protos(ax, protos, "w")
|
self.plot_protos(ax, protos, "w")
|
||||||
x = np.vstack((x_train, protos))
|
x = np.vstack((x_train, protos))
|
||||||
mesh_input, xx, yy = self.get_mesh_input(x)
|
mesh_input, xx, yy = self.get_mesh_input(x)
|
||||||
y_pred = pl_module.predict(torch.Tensor(mesh_input))
|
_components = pl_module.component_layer._components
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
y_pred = pl_module.predict(
|
||||||
|
torch.Tensor(mesh_input).type_as(_components))
|
||||||
|
y_pred = y_pred.cpu().reshape(xx.shape)
|
||||||
|
|
||||||
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
ax.contourf(xx, yy, y_pred, cmap=self.cmap, alpha=0.35)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user