[FEATURE] Add property reasoning_matrices

This commit is contained in:
Jensun Ravichandran 2021-06-16 13:39:09 +02:00
parent 454718cdf5
commit 7763a57058

View File

@ -262,20 +262,10 @@ class Reasonings(torch.nn.Module):
def num_classes(self): def num_classes(self):
return self._reasonings.shape[1] return self._reasonings.shape[1]
# @property
# def reasonings(self):
# """Tensor containing the reasoning matrices."""
# return self._reasonings.detach().cpu()
@property @property
def reasonings(self): def reasonings(self):
with torch.no_grad(): """Tensor containing the reasoning matrices."""
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1) return self._reasonings.detach().cpu()
pk = A
nk = (1 - pk) * B
ik = 1 - pk - nk
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1).cpu()
def _register_reasonings(self, reasonings): def _register_reasonings(self, reasonings):
self.register_buffer("_reasonings", reasonings) self.register_buffer("_reasonings", reasonings)
@ -330,6 +320,22 @@ class ReasoningComponents(AbstractComponents):
def num_classes(self): def num_classes(self):
return self._reasonings.shape[1] return self._reasonings.shape[1]
@property
def reasonings(self):
"""Tensor containing the reasoning matrices."""
return self._reasonings.detach().cpu()
@property
def reasoning_matrices(self):
"""Reasoning matrices for each class."""
with torch.no_grad():
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
pk = A
nk = (1 - pk) * B
ik = 1 - pk - nk
matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0)
return matrices.cpu()
def _register_reasonings(self, reasonings): def _register_reasonings(self, reasonings):
self.register_parameter("_reasonings", Parameter(reasonings)) self.register_parameter("_reasonings", Parameter(reasonings))