fix: accuracy as torchmetric fixed

This commit is contained in:
Alexander Engelsberger 2022-09-21 10:22:35 +02:00
parent 16ca409f07
commit ba50dfba50
No known key found for this signature in database
GPG Key ID: DE8669706B6AC2E7
2 changed files with 9 additions and 1 deletions

View File

@ -97,6 +97,13 @@ def main():
step=Steps.VALIDATION, step=Steps.VALIDATION,
) )
accuracy = LogTorchmetricCallback(
'validation_accuracy',
torchmetrics.Accuracy,
num_classes=3,
step=Steps.VALIDATION,
)
es = EarlyStopping( es = EarlyStopping(
monitor=stopping_criterion.name, monitor=stopping_criterion.name,
mode="max", mode="max",
@ -111,6 +118,7 @@ def main():
callbacks=[ callbacks=[
vis, vis,
recall, recall,
accuracy,
stopping_criterion, stopping_criterion,
es, es,
PlotLambdaMatrixToTensorboard(), PlotLambdaMatrixToTensorboard(),

View File

@ -237,7 +237,7 @@ class BaseYArchitecture(pl.LightningModule):
_, y = batch _, y = batch
for metric in self.registered_metrics[step]: for metric in self.registered_metrics[step]:
instance = self.registered_metrics[step][metric].to(self.device) instance = self.registered_metrics[step][metric].to(self.device)
instance(y, preds) instance(y, preds.reshape(y.shape))
def update_metrics_epoch(self, step): def update_metrics_epoch(self, step):
for metric in self.registered_metrics[step]: for metric in self.registered_metrics[step]: