Stop passing component initializers as hparams

Pass the component initializer as an hparam slows down the script very much. The
API has now been changed to pass it as a kwarg to the models instead.

The example scripts have also been updated to reflect the new changes.

Also, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also
been added.
This commit is contained in:
Jensun Ravichandran
2021-05-12 16:36:22 +02:00
parent 1498c4bde5
commit ca39aa00d5
11 changed files with 172 additions and 21 deletions

View File

@@ -3,6 +3,7 @@ import os
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnchoredText
from prototorch.utils.celluloid import Camera
@@ -270,6 +271,7 @@ class Vis2DAbstract(pl.Callback):
border=1,
resolution=50,
show_protos=True,
show=True,
tensorboard=False,
show_last_only=False,
pause_time=0.1,
@@ -290,6 +292,7 @@ class Vis2DAbstract(pl.Callback):
self.border = border
self.resolution = resolution
self.show_protos = show_protos
self.show = show
self.tensorboard = tensorboard
self.show_last_only = show_last_only
self.pause_time = pause_time
@@ -352,10 +355,11 @@ class Vis2DAbstract(pl.Callback):
def log_and_display(self, trainer, pl_module):
if self.tensorboard:
self.add_to_tensorboard(trainer, pl_module)
if not self.block:
plt.pause(self.pause_time)
else:
plt.show(block=True)
if self.show:
if not self.block:
plt.pause(self.pause_time)
else:
plt.show(block=True)
def on_train_end(self, trainer, pl_module):
plt.show()
@@ -458,3 +462,50 @@ class VisNG2D(Vis2DAbstract):
)
self.log_and_display(trainer, pl_module)
class VisImgComp(Vis2DAbstract):
def __init__(self,
*args,
random_data=0,
dataformats="CHW",
nrow=2,
**kwargs):
super().__init__(*args, **kwargs)
self.random_data = random_data
self.dataformats = dataformats
self.nrow = nrow
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True
if self.show:
components = pl_module.components
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
self.log_and_display(trainer, pl_module)
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
components = pl_module.components
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
tb.add_image(
tag="Components",
img_tensor=grid,
global_step=trainer.current_epoch,
dataformats=self.dataformats,
)
if self.random_data:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data_img = self.x_train[ind]
grid = torchvision.utils.make_grid(data_img, nrow=self.nrow)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=trainer.current_epoch,
dataformats=self.dataformats)