prototorch_models/examples/cbc_mnist.py

129 lines
3.7 KiB
Python
Raw Normal View History

"""CBC example using the MNIST dataset.
This script also shows how to use Tensorboard for visualizing the prototypes.
"""
import argparse
import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
2021-04-23 15:27:47 +00:00
from prototorch.models.cbc import CBC, ImageCBC, euclidean_similarity
class VisualizationCallback(pl.Callback):
    """Log CBC components and reasonings as TensorBoard images.

    Runs in the ``on_epoch_end`` hook and writes through the module
    logger's ``experiment`` object.

    Args:
        to_shape: Shape the ``components`` tensor is reshaped to before it
            is tiled into an image grid (default: single-channel 28x28).
        nrow: Number of images per row in the component grid.
    """

    def __init__(self, to_shape=(-1, 1, 28, 28), nrow=2):
        super().__init__()
        self.to_shape = to_shape
        self.nrow = nrow

    def on_epoch_end(self, trainer, pl_module: ImageCBC):
        writer = pl_module.logger.experiment
        step = trainer.current_epoch

        # Components: reshape to image batch, then tile into one CHW image.
        tiled = torchvision.utils.make_grid(
            pl_module.components.reshape(self.to_shape),
            nrow=self.nrow,
        )
        writer.add_image(tag="MNIST Components",
                         img_tensor=tiled,
                         global_step=step,
                         dataformats="CHW")

        # Reasonings: logged as-is as a batch of images
        # (assumed to already be NCHW -- TODO confirm against ImageCBC).
        writer.add_images(tag="MNIST Reasoning",
                          img_tensor=pl_module.reasonings,
                          global_step=step,
                          dataformats="NCHW")


def parse_args(argv=None):
    """Parse the command-line arguments of this example.

    Args:
        argv: Argument list to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``epochs``, ``lr``, ``batch_size``,
        ``gpus`` and ``ppc`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs",
                        type=int,
                        default=10,
                        help="Epochs to train.")
    # The --lr and --batch_size defaults equal the values the script used to
    # hard-code (ignoring these flags entirely), so a run without arguments
    # behaves exactly as before while the flags now take effect.
    parser.add_argument("--lr",
                        type=float,
                        default=0.01,
                        help="Learning rate.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=32,
                        help="Batch size.")
    parser.add_argument("--gpus",
                        type=int,
                        default=0,
                        help="Number of GPUs to use.")
    parser.add_argument("--ppc",
                        type=int,
                        default=1,
                        help="Prototypes-Per-Class.")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Arguments
    args = parse_args()

    # Dataset: both splits share one normalization pipeline.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    mnist_train = MNIST(
        "./datasets",
        train=True,
        download=True,
        transform=transform,
    )
    mnist_test = MNIST(
        "./datasets",
        train=False,
        download=True,
        transform=transform,
    )

    # Dataloaders -- use the --batch_size flag rather than a hard-coded value.
    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # Grab the full training set in a single batch to warm-start prototypes;
    # each image is flattened to a 28*28 vector.
    x, y = next(iter(DataLoader(mnist_train, batch_size=len(mnist_train))))
    x = x.view(len(mnist_train), -1)

    # Hyperparameters -- lr now comes from the --lr flag.
    hparams = dict(
        input_dim=28 * 28,
        nclasses=10,
        prototypes_per_class=args.ppc,
        prototype_initializer="randn",
        lr=args.lr,
        similarity=euclidean_similarity,
    )

    # Initialize the model
    model = CBC(hparams, data=[x, y])

    # Model summary
    print(model)

    # Callbacks
    vis = VisualizationCallback(to_shape=(-1, 1, 28, 28), nrow=args.ppc)

    # Setup trainer
    trainer = pl.Trainer(
        gpus=args.gpus,  # change to use GPUs for training
        max_epochs=args.epochs,
        callbacks=[vis],
        track_grad_norm=2,
        # accelerator="ddp_cpu",  # DEBUG-ONLY
        # num_processes=2,  # DEBUG-ONLY
    )

    # Training loop; the MNIST test split is passed as the validation loader.
    trainer.fit(model, train_loader, test_loader)