[BUGFIX] Fix reasonings initializer dimension bug

This commit is contained in:
Jensun Ravichandran
2021-06-17 18:10:05 +02:00
parent ae11fedbf3
commit de61bf7bca
4 changed files with 23 additions and 20 deletions

View File

@@ -42,7 +42,6 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
nk = (1 - A) * B
numerator = (detections @ (pk - nk).T) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
# probs = probs.squeeze(0)
return probs

View File

@@ -13,6 +13,7 @@ from .initializers import (
AbstractLabelsInitializer,
AbstractReasoningsInitializer,
LabelsInitializer,
PurePositiveReasoningsInitializer,
RandomReasoningsInitializer,
)
@@ -308,10 +309,13 @@ class ReasoningComponents(AbstractComponents):
three element probability distribution.
"""
def __init__(self, distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
reasonings_initializer: AbstractReasoningsInitializer,
**kwargs):
def __init__(
self,
distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
reasonings_initializer:
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(),
**kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
reasonings_initializer)

View File

@@ -296,7 +296,7 @@ class OneHotLabelsInitializer(LabelsInitializer):
# Reasonings
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
def __init__(self, components_first=True):
def __init__(self, components_first: bool = True):
self.components_first = components_first
def compute_shape(self, distribution):
@@ -375,7 +375,7 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
num_components, num_classes, _ = self.compute_shape(distribution)
A = OneHotLabelsInitializer().generate(distribution)
B = torch.zeros(num_components, num_classes)
reasonings = torch.stack([A, B]).permute(2, 1, 0)
reasonings = torch.stack([A, B], dim=-1)
reasonings = self.generate_end_hook(reasonings)
return reasonings