Stop passing component initializers as hparams
Pass the component initializer as an hparam slows down the script very much. The API has now been changed to pass it as a kwarg to the models instead. The example scripts have also been updated to reflect the new changes. Also, ImageGMLVQ and an example script `gmlvq_mnist.py` that uses it have also been added.
This commit is contained in:
@@ -3,6 +3,7 @@ import os
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torchvision
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.offsetbox import AnchoredText
|
||||
from prototorch.utils.celluloid import Camera
|
||||
@@ -270,6 +271,7 @@ class Vis2DAbstract(pl.Callback):
|
||||
border=1,
|
||||
resolution=50,
|
||||
show_protos=True,
|
||||
show=True,
|
||||
tensorboard=False,
|
||||
show_last_only=False,
|
||||
pause_time=0.1,
|
||||
@@ -290,6 +292,7 @@ class Vis2DAbstract(pl.Callback):
|
||||
self.border = border
|
||||
self.resolution = resolution
|
||||
self.show_protos = show_protos
|
||||
self.show = show
|
||||
self.tensorboard = tensorboard
|
||||
self.show_last_only = show_last_only
|
||||
self.pause_time = pause_time
|
||||
@@ -352,10 +355,11 @@ class Vis2DAbstract(pl.Callback):
|
||||
def log_and_display(self, trainer, pl_module):
|
||||
if self.tensorboard:
|
||||
self.add_to_tensorboard(trainer, pl_module)
|
||||
if not self.block:
|
||||
plt.pause(self.pause_time)
|
||||
else:
|
||||
plt.show(block=True)
|
||||
if self.show:
|
||||
if not self.block:
|
||||
plt.pause(self.pause_time)
|
||||
else:
|
||||
plt.show(block=True)
|
||||
|
||||
def on_train_end(self, trainer, pl_module):
|
||||
plt.show()
|
||||
@@ -458,3 +462,50 @@ class VisNG2D(Vis2DAbstract):
|
||||
)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
|
||||
class VisImgComp(Vis2DAbstract):
|
||||
def __init__(self,
|
||||
*args,
|
||||
random_data=0,
|
||||
dataformats="CHW",
|
||||
nrow=2,
|
||||
**kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.random_data = random_data
|
||||
self.dataformats = dataformats
|
||||
self.nrow = nrow
|
||||
|
||||
def on_epoch_end(self, trainer, pl_module):
|
||||
if not self.precheck(trainer):
|
||||
return True
|
||||
|
||||
if self.show:
|
||||
components = pl_module.components
|
||||
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
|
||||
plt.imshow(grid.permute((1, 2, 0)).cpu(), cmap=self.cmap)
|
||||
|
||||
self.log_and_display(trainer, pl_module)
|
||||
|
||||
def add_to_tensorboard(self, trainer, pl_module):
|
||||
tb = pl_module.logger.experiment
|
||||
|
||||
components = pl_module.components
|
||||
grid = torchvision.utils.make_grid(components, nrow=self.nrow)
|
||||
tb.add_image(
|
||||
tag="Components",
|
||||
img_tensor=grid,
|
||||
global_step=trainer.current_epoch,
|
||||
dataformats=self.dataformats,
|
||||
)
|
||||
|
||||
if self.random_data:
|
||||
ind = np.random.choice(len(self.x_train),
|
||||
size=self.random_data,
|
||||
replace=False)
|
||||
data_img = self.x_train[ind]
|
||||
grid = torchvision.utils.make_grid(data_img, nrow=self.nrow)
|
||||
tb.add_image(tag="Data",
|
||||
img_tensor=grid,
|
||||
global_step=trainer.current_epoch,
|
||||
dataformats=self.dataformats)
|
||||
|
Reference in New Issue
Block a user