Return prototypes as torch tensor
This commit is contained in:
parent
63a5a98491
commit
11b3e53ecb
@ -20,4 +20,4 @@ class AbstractLightningModel(pl.LightningModule):
|
||||
class AbstractPrototypeModel(AbstractLightningModel):
|
||||
@property
|
||||
def prototypes(self):
|
||||
return self.proto_layer.components.detach().numpy()
|
||||
return self.proto_layer.components.detach().cpu()
|
||||
|
Loading…
Reference in New Issue
Block a user