Update image visualization

This commit is contained in:
Jensun Ravichandran 2021-05-20 16:07:16 +02:00
parent df061cc2ff
commit 16dc3cf4eb
2 changed files with 46 additions and 13 deletions

View File

@ -64,6 +64,9 @@ if __name__ == "__main__":
nrow=5,
show=False,
tensorboard=True,
random_data=20,
add_embedding=True,
flatten_data=False,
)
# Setup trainer
@ -73,7 +76,7 @@ if __name__ == "__main__":
# kwargs override the cli-arguments
# max_epochs=50,
# overfit_batches=1,
# fast_dev_run=3,
# fast_dev_run=1,
)
# Training loop

View File

@ -20,6 +20,7 @@ class Vis2DAbstract(pl.Callback):
cmap="viridis",
border=0.1,
resolution=100,
flatten_data=True,
axis_off=False,
show_protos=True,
show=True,
@ -31,9 +32,18 @@ class Vis2DAbstract(pl.Callback):
if isinstance(data, Dataset):
x, y = next(iter(DataLoader(data, batch_size=len(data))))
x = x.view(len(data), -1) # flatten
elif isinstance(data, torch.utils.data.DataLoader):
x = torch.tensor([])
y = torch.tensor([])
for x_b, y_b in data:
x = torch.cat([x, x_b])
y = torch.cat([y, y_b])
else:
x, y = data
if flatten_data:
x = x.view(len(data), -1)
self.x_train = x
self.y_train = y
@ -237,11 +247,42 @@ class VisImgComp(Vis2DAbstract):
random_data=0,
dataformats="CHW",
nrow=2,
add_embedding=False,
embedding_data=100,
**kwargs):
super().__init__(*args, **kwargs)
self.random_data = random_data
self.dataformats = dataformats
self.nrow = nrow
self.add_embedding = add_embedding
self.embedding_data = embedding_data
def on_train_start(self, trainer, pl_module):
if self.add_embedding:
ind = np.random.choice(len(self.x_train),
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
tb = pl_module.logger.experiment
# print(f"{data.shape=}")
# print(f"{self.y_train[ind].shape=}")
tb.add_embedding(data.view(len(ind), -1),
label_img=data,
global_step=None,
tag="Data Embedding",
metadata=self.y_train[ind],
metadata_header=None)
if self.random_data:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data = self.x_train[ind]
grid = torchvision.utils.make_grid(data, nrow=self.nrow)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=None,
dataformats=self.dataformats)
def add_to_tensorboard(self, trainer, pl_module):
tb = pl_module.logger.experiment
@ -255,17 +296,6 @@ class VisImgComp(Vis2DAbstract):
dataformats=self.dataformats,
)
if self.random_data:
ind = np.random.choice(len(self.x_train),
size=self.random_data,
replace=False)
data_img = self.x_train[ind]
grid = torchvision.utils.make_grid(data_img, nrow=self.nrow)
tb.add_image(tag="Data",
img_tensor=grid,
global_step=trainer.current_epoch,
dataformats=self.dataformats)
def on_epoch_end(self, trainer, pl_module):
if not self.precheck(trainer):
return True