[BUGFIX] Fix image visualization for some parameter combination
image visualization was broken if add_embeding was False, but data visualization was on.
This commit is contained in:
parent
e209bf73d5
commit
7379c61966
@ -253,12 +253,12 @@ class VisImgComp(Vis2DAbstract):
|
||||
self.embedding_data = embedding_data
|
||||
|
||||
def on_train_start(self, trainer, pl_module):
|
||||
tb = pl_module.logger.experiment
|
||||
if self.add_embedding:
|
||||
ind = np.random.choice(len(self.x_train),
|
||||
size=self.embedding_data,
|
||||
replace=False)
|
||||
data = self.x_train[ind]
|
||||
tb = pl_module.logger.experiment
|
||||
# print(f"{data.shape=}")
|
||||
# print(f"{self.y_train[ind].shape=}")
|
||||
tb.add_embedding(data.view(len(ind), -1),
|
||||
|
Loading…
Reference in New Issue
Block a user