fix: accuracy as torchmetric fixed

This commit is contained in:
Alexander Engelsberger
2022-09-21 10:22:35 +02:00
parent 16ca409f07
commit ba50dfba50
2 changed files with 9 additions and 1 deletions

View File

@@ -237,7 +237,7 @@ class BaseYArchitecture(pl.LightningModule):
_, y = batch
for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds)
instance(y, preds.reshape(y.shape))
def update_metrics_epoch(self, step):
for metric in self.registered_metrics[step]: