diff --git a/prototorch/models/callbacks.py b/prototorch/models/callbacks.py index ef3fb77..d088c7d 100644 --- a/prototorch/models/callbacks.py +++ b/prototorch/models/callbacks.py @@ -44,7 +44,7 @@ class PruneLoserPrototypes(pl.Callback): if self.verbose: print(f"\nPrototype win ratios: {ratios}") print(f"Pruning prototypes at: {to_prune}") - print(f"Corresponding labels are: {prune_labels}") + print(f"Corresponding labels are: {prune_labels.tolist()}") cur_num_protos = pl_module.num_prototypes pl_module.remove_prototypes(indices=to_prune) if self.replace: