[FEATURE] Add property reasoning_matrices
This commit is contained in:
parent
454718cdf5
commit
7763a57058
@ -262,20 +262,10 @@ class Reasonings(torch.nn.Module):
|
||||
def num_classes(self):
|
||||
return self._reasonings.shape[1]
|
||||
|
||||
# @property
|
||||
# def reasonings(self):
|
||||
# """Tensor containing the reasoning matrices."""
|
||||
# return self._reasonings.detach().cpu()
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
with torch.no_grad():
|
||||
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||
pk = A
|
||||
nk = (1 - pk) * B
|
||||
ik = 1 - pk - nk
|
||||
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
|
||||
return img.unsqueeze(1).cpu()
|
||||
"""Tensor containing the reasoning matrices."""
|
||||
return self._reasonings.detach().cpu()
|
||||
|
||||
def _register_reasonings(self, reasonings):
|
||||
self.register_buffer("_reasonings", reasonings)
|
||||
@ -330,6 +320,22 @@ class ReasoningComponents(AbstractComponents):
|
||||
def num_classes(self):
|
||||
return self._reasonings.shape[1]
|
||||
|
||||
@property
|
||||
def reasonings(self):
|
||||
"""Tensor containing the reasoning matrices."""
|
||||
return self._reasonings.detach().cpu()
|
||||
|
||||
@property
|
||||
def reasoning_matrices(self):
|
||||
"""Reasoning matrices for each class."""
|
||||
with torch.no_grad():
|
||||
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||
pk = A
|
||||
nk = (1 - pk) * B
|
||||
ik = 1 - pk - nk
|
||||
matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0)
|
||||
return matrices.cpu()
|
||||
|
||||
def _register_reasonings(self, reasonings):
|
||||
self.register_parameter("_reasonings", Parameter(reasonings))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user