From 11b3e53ecb61402876bfd99ebf09bfbdc433c59b Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:45:37 +0200 Subject: [PATCH] Return prototypes as torch tensor --- prototorch/models/abstract.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prototorch/models/abstract.py b/prototorch/models/abstract.py index dcc89a9..562b6d3 100644 --- a/prototorch/models/abstract.py +++ b/prototorch/models/abstract.py @@ -20,4 +20,4 @@ class AbstractLightningModel(pl.LightningModule): class AbstractPrototypeModel(AbstractLightningModel): @property def prototypes(self): - return self.proto_layer.components.detach().numpy() + return self.proto_layer.components.detach().cpu()