fix: accuracy as torchmetric fixed

This commit is contained in:
Alexander Engelsberger 2022-09-21 10:22:35 +02:00
parent 16ca409f07
commit ba50dfba50
No known key found for this signature in database
GPG Key ID: DE8669706B6AC2E7
2 changed files with 9 additions and 1 deletions

View File

@ -97,6 +97,13 @@ def main():
step=Steps.VALIDATION,
)
accuracy = LogTorchmetricCallback(
'validation_accuracy',
torchmetrics.Accuracy,
num_classes=3,
step=Steps.VALIDATION,
)
es = EarlyStopping(
monitor=stopping_criterion.name,
mode="max",
@ -111,6 +118,7 @@ def main():
callbacks=[
vis,
recall,
accuracy,
stopping_criterion,
es,
PlotLambdaMatrixToTensorboard(),

View File

@ -237,7 +237,7 @@ class BaseYArchitecture(pl.LightningModule):
_, y = batch
for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds)
instance(y, preds.reshape(y.shape))
def update_metrics_epoch(self, step):
for metric in self.registered_metrics[step]: