fix: accuracy as torchmetric fixed
This commit is contained in:
parent
16ca409f07
commit
ba50dfba50
@ -97,6 +97,13 @@ def main():
|
|||||||
step=Steps.VALIDATION,
|
step=Steps.VALIDATION,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
accuracy = LogTorchmetricCallback(
|
||||||
|
'validation_accuracy',
|
||||||
|
torchmetrics.Accuracy,
|
||||||
|
num_classes=3,
|
||||||
|
step=Steps.VALIDATION,
|
||||||
|
)
|
||||||
|
|
||||||
es = EarlyStopping(
|
es = EarlyStopping(
|
||||||
monitor=stopping_criterion.name,
|
monitor=stopping_criterion.name,
|
||||||
mode="max",
|
mode="max",
|
||||||
@ -111,6 +118,7 @@ def main():
|
|||||||
callbacks=[
|
callbacks=[
|
||||||
vis,
|
vis,
|
||||||
recall,
|
recall,
|
||||||
|
accuracy,
|
||||||
stopping_criterion,
|
stopping_criterion,
|
||||||
es,
|
es,
|
||||||
PlotLambdaMatrixToTensorboard(),
|
PlotLambdaMatrixToTensorboard(),
|
||||||
|
@ -237,7 +237,7 @@ class BaseYArchitecture(pl.LightningModule):
|
|||||||
_, y = batch
|
_, y = batch
|
||||||
for metric in self.registered_metrics[step]:
|
for metric in self.registered_metrics[step]:
|
||||||
instance = self.registered_metrics[step][metric].to(self.device)
|
instance = self.registered_metrics[step][metric].to(self.device)
|
||||||
instance(y, preds)
|
instance(y, preds.reshape(y.shape))
|
||||||
|
|
||||||
def update_metrics_epoch(self, step):
|
def update_metrics_epoch(self, step):
|
||||||
for metric in self.registered_metrics[step]:
|
for metric in self.registered_metrics[step]:
|
||||||
|
Loading…
Reference in New Issue
Block a user