Return prototypes as torch tensor

This commit is contained in:
Jensun Ravichandran 2021-05-07 15:45:37 +02:00
parent 63a5a98491
commit 11b3e53ecb

View File

@ -20,4 +20,4 @@ class AbstractLightningModel(pl.LightningModule):
class AbstractPrototypeModel(AbstractLightningModel): class AbstractPrototypeModel(AbstractLightningModel):
@property @property
def prototypes(self): def prototypes(self):
return self.proto_layer.components.detach().numpy() return self.proto_layer.components.detach().cpu()