[BUGFIX] Fix image visualization for some parameter combination
image visualization was broken if add_embeding was False, but data visualization was on.
This commit is contained in:
parent
e209bf73d5
commit
7379c61966
@ -253,12 +253,12 @@ class VisImgComp(Vis2DAbstract):
|
|||||||
self.embedding_data = embedding_data
|
self.embedding_data = embedding_data
|
||||||
|
|
||||||
def on_train_start(self, trainer, pl_module):
|
def on_train_start(self, trainer, pl_module):
|
||||||
|
tb = pl_module.logger.experiment
|
||||||
if self.add_embedding:
|
if self.add_embedding:
|
||||||
ind = np.random.choice(len(self.x_train),
|
ind = np.random.choice(len(self.x_train),
|
||||||
size=self.embedding_data,
|
size=self.embedding_data,
|
||||||
replace=False)
|
replace=False)
|
||||||
data = self.x_train[ind]
|
data = self.x_train[ind]
|
||||||
tb = pl_module.logger.experiment
|
|
||||||
# print(f"{data.shape=}")
|
# print(f"{data.shape=}")
|
||||||
# print(f"{self.y_train[ind].shape=}")
|
# print(f"{self.y_train[ind].shape=}")
|
||||||
tb.add_embedding(data.view(len(ind), -1),
|
tb.add_embedding(data.view(len(ind), -1),
|
||||||
|
Loading…
Reference in New Issue
Block a user