diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index dcc89a9..562b6d3 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -20,4 +20,4 @@ class AbstractLightningModel(pl.LightningModule): class AbstractPrototypeModel(AbstractLightningModel): @property def prototypes(self): - return self.proto_layer.components.detach().numpy() + return self.proto_layer.components.detach().cpu()