chore: minor changes and version updates
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
"""ProtoTorch CBC example using 2D Iris data."""
|
||||
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
@@ -34,7 +36,7 @@ class VisCBC2D():
|
||||
self.resolution = 100
|
||||
self.cmap = "viridis"
|
||||
|
||||
def on_epoch_end(self):
|
||||
def on_train_epoch_end(self):
|
||||
x_train, y_train = self.x_train, self.y_train
|
||||
_components = self.model.components_layer._components.detach()
|
||||
ax = self.fig.gca()
|
||||
@@ -94,5 +96,5 @@ if __name__ == "__main__":
|
||||
correct += (y_pred.argmax(1) == y).float().sum(0)
|
||||
|
||||
acc = 100 * correct / len(train_ds)
|
||||
print(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||
vis.on_epoch_end()
|
||||
logging.info(f"Epoch: {epoch} Accuracy: {acc:05.02f}%")
|
||||
vis.on_train_epoch_end()
|
||||
|
Reference in New Issue
Block a user