diff --git a/prototorch/models/vis.py b/prototorch/models/vis.py index c50f19c..1e09437 100644 --- a/prototorch/models/vis.py +++ b/prototorch/models/vis.py @@ -253,12 +253,12 @@ class VisImgComp(Vis2DAbstract): self.embedding_data = embedding_data def on_train_start(self, trainer, pl_module): + tb = pl_module.logger.experiment if self.add_embedding: ind = np.random.choice(len(self.x_train), size=self.embedding_data, replace=False) data = self.x_train[ind] - tb = pl_module.logger.experiment # print(f"{data.shape=}") # print(f"{self.y_train[ind].shape=}") tb.add_embedding(data.view(len(ind), -1),