From 7763a57058e1e47d27a8ce5dc8e1d32e320ec4c8 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 13:39:09 +0200 Subject: [PATCH] [FEATURE] Add property `reasoning_matrices` --- prototorch/core/components.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index d497cdf..c9edcbb 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -262,20 +262,10 @@ class Reasonings(torch.nn.Module): def num_classes(self): return self._reasonings.shape[1] - # @property - # def reasonings(self): - # """Tensor containing the reasoning matrices.""" - # return self._reasonings.detach().cpu() - @property def reasonings(self): - with torch.no_grad(): - A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1) - pk = A - nk = (1 - pk) * B - ik = 1 - pk - nk - img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2) - return img.unsqueeze(1).cpu() + """Tensor containing the reasoning matrices.""" + return self._reasonings.detach().cpu() def _register_reasonings(self, reasonings): self.register_buffer("_reasonings", reasonings) @@ -330,6 +320,22 @@ class ReasoningComponents(AbstractComponents): def num_classes(self): return self._reasonings.shape[1] + @property + def reasonings(self): + """Tensor containing the reasoning matrices.""" + return self._reasonings.detach().cpu() + + @property + def reasoning_matrices(self): + """Reasoning matrices for each class.""" + with torch.no_grad(): + A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1) + pk = A + nk = (1 - pk) * B + ik = 1 - pk - nk + matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0) + return matrices.cpu() + def _register_reasonings(self, reasonings): self.register_parameter("_reasonings", Parameter(reasonings))