fix: accuracy as torchmetric fixed
This commit is contained in:
parent
16ca409f07
commit
ba50dfba50
@ -97,6 +97,13 @@ def main():
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
accuracy = LogTorchmetricCallback(
|
||||
'validation_accuracy',
|
||||
torchmetrics.Accuracy,
|
||||
num_classes=3,
|
||||
step=Steps.VALIDATION,
|
||||
)
|
||||
|
||||
es = EarlyStopping(
|
||||
monitor=stopping_criterion.name,
|
||||
mode="max",
|
||||
@ -111,6 +118,7 @@ def main():
|
||||
callbacks=[
|
||||
vis,
|
||||
recall,
|
||||
accuracy,
|
||||
stopping_criterion,
|
||||
es,
|
||||
PlotLambdaMatrixToTensorboard(),
|
||||
|
@ -237,7 +237,7 @@ class BaseYArchitecture(pl.LightningModule):
|
||||
_, y = batch
|
||||
for metric in self.registered_metrics[step]:
|
||||
instance = self.registered_metrics[step][metric].to(self.device)
|
||||
instance(y, preds)
|
||||
instance(y, preds.reshape(y.shape))
|
||||
|
||||
def update_metrics_epoch(self, step):
|
||||
for metric in self.registered_metrics[step]:
|
||||
|
Loading…
Reference in New Issue
Block a user