[BUGFIX] Fix image visualization for some parameter combination

image visualization was broken if add_embeding was False, but data visualization was on.
This commit is contained in:
Alexander Engelsberger 2021-06-03 15:12:51 +02:00
parent e209bf73d5
commit 7379c61966

View File

@ -253,12 +253,12 @@ class VisImgComp(Vis2DAbstract):
self.embedding_data = embedding_data
def on_train_start(self, trainer, pl_module):
tb = pl_module.logger.experiment
if self.add_embedding:
ind = np.random.choice(len(self.x_train),
size=self.embedding_data,
replace=False)
data = self.x_train[ind]
tb = pl_module.logger.experiment
# print(f"{data.shape=}")
# print(f"{self.y_train[ind].shape=}")
tb.add_embedding(data.view(len(ind), -1),