Update image visualization
This commit is contained in:
parent
df061cc2ff
commit
16dc3cf4eb
@ -64,6 +64,9 @@ if __name__ == "__main__":
|
|||||||
nrow=5,
|
nrow=5,
|
||||||
show=False,
|
show=False,
|
||||||
tensorboard=True,
|
tensorboard=True,
|
||||||
|
random_data=20,
|
||||||
|
add_embedding=True,
|
||||||
|
flatten_data=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup trainer
|
# Setup trainer
|
||||||
@ -73,7 +76,7 @@ if __name__ == "__main__":
|
|||||||
# kwargs override the cli-arguments
|
# kwargs override the cli-arguments
|
||||||
# max_epochs=50,
|
# max_epochs=50,
|
||||||
# overfit_batches=1,
|
# overfit_batches=1,
|
||||||
# fast_dev_run=3,
|
# fast_dev_run=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
|
@ -20,6 +20,7 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
cmap="viridis",
|
cmap="viridis",
|
||||||
border=0.1,
|
border=0.1,
|
||||||
resolution=100,
|
resolution=100,
|
||||||
|
flatten_data=True,
|
||||||
axis_off=False,
|
axis_off=False,
|
||||||
show_protos=True,
|
show_protos=True,
|
||||||
show=True,
|
show=True,
|
||||||
@ -31,9 +32,18 @@ class Vis2DAbstract(pl.Callback):
|
|||||||
|
|
||||||
if isinstance(data, Dataset):
|
if isinstance(data, Dataset):
|
||||||
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
x, y = next(iter(DataLoader(data, batch_size=len(data))))
|
||||||
x = x.view(len(data), -1) # flatten
|
elif isinstance(data, torch.utils.data.DataLoader):
|
||||||
|
x = torch.tensor([])
|
||||||
|
y = torch.tensor([])
|
||||||
|
for x_b, y_b in data:
|
||||||
|
x = torch.cat([x, x_b])
|
||||||
|
y = torch.cat([y, y_b])
|
||||||
else:
|
else:
|
||||||
x, y = data
|
x, y = data
|
||||||
|
|
||||||
|
if flatten_data:
|
||||||
|
x = x.view(len(data), -1)
|
||||||
|
|
||||||
self.x_train = x
|
self.x_train = x
|
||||||
self.y_train = y
|
self.y_train = y
|
||||||
|
|
||||||
@ -237,11 +247,42 @@ class VisImgComp(Vis2DAbstract):
|
|||||||
random_data=0,
|
random_data=0,
|
||||||
dataformats="CHW",
|
dataformats="CHW",
|
||||||
nrow=2,
|
nrow=2,
|
||||||
|
add_embedding=False,
|
||||||
|
embedding_data=100,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.random_data = random_data
|
self.random_data = random_data
|
||||||
self.dataformats = dataformats
|
self.dataformats = dataformats
|
||||||
self.nrow = nrow
|
self.nrow = nrow
|
||||||
|
self.add_embedding = add_embedding
|
||||||
|
self.embedding_data = embedding_data
|
||||||
|
|
||||||
|
def on_train_start(self, trainer, pl_module):
|
||||||
|
if self.add_embedding:
|
||||||
|
ind = np.random.choice(len(self.x_train),
|
||||||
|
size=self.embedding_data,
|
||||||
|
replace=False)
|
||||||
|
data = self.x_train[ind]
|
||||||
|
tb = pl_module.logger.experiment
|
||||||
|
# print(f"{data.shape=}")
|
||||||
|
# print(f"{self.y_train[ind].shape=}")
|
||||||
|
tb.add_embedding(data.view(len(ind), -1),
|
||||||
|
label_img=data,
|
||||||
|
global_step=None,
|
||||||
|
tag="Data Embedding",
|
||||||
|
metadata=self.y_train[ind],
|
||||||
|
metadata_header=None)
|
||||||
|
|
||||||
|
if self.random_data:
|
||||||
|
ind = np.random.choice(len(self.x_train),
|
||||||
|
size=self.random_data,
|
||||||
|
replace=False)
|
||||||
|
data = self.x_train[ind]
|
||||||
|
grid = torchvision.utils.make_grid(data, nrow=self.nrow)
|
||||||
|
tb.add_image(tag="Data",
|
||||||
|
img_tensor=grid,
|
||||||
|
global_step=None,
|
||||||
|
dataformats=self.dataformats)
|
||||||
|
|
||||||
def add_to_tensorboard(self, trainer, pl_module):
|
def add_to_tensorboard(self, trainer, pl_module):
|
||||||
tb = pl_module.logger.experiment
|
tb = pl_module.logger.experiment
|
||||||
@ -255,17 +296,6 @@ class VisImgComp(Vis2DAbstract):
|
|||||||
dataformats=self.dataformats,
|
dataformats=self.dataformats,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.random_data:
|
|
||||||
ind = np.random.choice(len(self.x_train),
|
|
||||||
size=self.random_data,
|
|
||||||
replace=False)
|
|
||||||
data_img = self.x_train[ind]
|
|
||||||
grid = torchvision.utils.make_grid(data_img, nrow=self.nrow)
|
|
||||||
tb.add_image(tag="Data",
|
|
||||||
img_tensor=grid,
|
|
||||||
global_step=trainer.current_epoch,
|
|
||||||
dataformats=self.dataformats)
|
|
||||||
|
|
||||||
def on_epoch_end(self, trainer, pl_module):
|
def on_epoch_end(self, trainer, pl_module):
|
||||||
if not self.precheck(trainer):
|
if not self.precheck(trainer):
|
||||||
return True
|
return True
|
||||||
|
Loading…
Reference in New Issue
Block a user