[FEATURE] Add property reasoning_matrices
This commit is contained in:
parent
454718cdf5
commit
7763a57058
@ -262,20 +262,10 @@ class Reasonings(torch.nn.Module):
|
|||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
return self._reasonings.shape[1]
|
return self._reasonings.shape[1]
|
||||||
|
|
||||||
# @property
|
|
||||||
# def reasonings(self):
|
|
||||||
# """Tensor containing the reasoning matrices."""
|
|
||||||
# return self._reasonings.detach().cpu()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def reasonings(self):
|
def reasonings(self):
|
||||||
with torch.no_grad():
|
"""Tensor containing the reasoning matrices."""
|
||||||
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
|
return self._reasonings.detach().cpu()
|
||||||
pk = A
|
|
||||||
nk = (1 - pk) * B
|
|
||||||
ik = 1 - pk - nk
|
|
||||||
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
|
|
||||||
return img.unsqueeze(1).cpu()
|
|
||||||
|
|
||||||
def _register_reasonings(self, reasonings):
|
def _register_reasonings(self, reasonings):
|
||||||
self.register_buffer("_reasonings", reasonings)
|
self.register_buffer("_reasonings", reasonings)
|
||||||
@ -330,6 +320,22 @@ class ReasoningComponents(AbstractComponents):
|
|||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
return self._reasonings.shape[1]
|
return self._reasonings.shape[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reasonings(self):
|
||||||
|
"""Tensor containing the reasoning matrices."""
|
||||||
|
return self._reasonings.detach().cpu()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reasoning_matrices(self):
|
||||||
|
"""Reasoning matrices for each class."""
|
||||||
|
with torch.no_grad():
|
||||||
|
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
|
||||||
|
pk = A
|
||||||
|
nk = (1 - pk) * B
|
||||||
|
ik = 1 - pk - nk
|
||||||
|
matrices = torch.stack([pk, nk, ik], dim=-1).permute(1, 2, 0)
|
||||||
|
return matrices.cpu()
|
||||||
|
|
||||||
def _register_reasonings(self, reasonings):
|
def _register_reasonings(self, reasonings):
|
||||||
self.register_parameter("_reasonings", Parameter(reasonings))
|
self.register_parameter("_reasonings", Parameter(reasonings))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user