[BUGFIX] Fix reasonings initializer dimension bug

This commit is contained in:
Jensun Ravichandran 2021-06-17 18:10:05 +02:00
parent ae11fedbf3
commit de61bf7bca
4 changed files with 23 additions and 20 deletions

View File

@ -42,7 +42,6 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
nk = (1 - A) * B
numerator = (detections @ (pk - nk).T) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
# probs = probs.squeeze(0)
return probs

View File

@ -13,6 +13,7 @@ from .initializers import (
AbstractLabelsInitializer,
AbstractReasoningsInitializer,
LabelsInitializer,
PurePositiveReasoningsInitializer,
RandomReasoningsInitializer,
)
@ -308,10 +309,13 @@ class ReasoningComponents(AbstractComponents):
three element probability distribution.
"""
def __init__(self, distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
reasonings_initializer: AbstractReasoningsInitializer,
**kwargs):
def __init__(
self,
distribution: Union[dict, list, tuple],
components_initializer: AbstractComponentsInitializer,
reasonings_initializer:
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(),
**kwargs):
super().__init__(**kwargs)
self.add_components(distribution, components_initializer,
reasonings_initializer)

View File

@ -296,7 +296,7 @@ class OneHotLabelsInitializer(LabelsInitializer):
# Reasonings
class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers."""
def __init__(self, components_first=True):
def __init__(self, components_first: bool = True):
self.components_first = components_first
def compute_shape(self, distribution):
@ -375,7 +375,7 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
num_components, num_classes, _ = self.compute_shape(distribution)
A = OneHotLabelsInitializer().generate(distribution)
B = torch.zeros(num_components, num_classes)
reasonings = torch.stack([A, B]).permute(2, 1, 0)
reasonings = torch.stack([A, B], dim=-1)
reasonings = self.generate_end_hook(reasonings)
return reasonings

View File

@ -220,13 +220,6 @@ def test_ones_reasonings_init():
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
def test_random_reasonings_init_channels_not_first():
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
reasonings = r.generate(distribution=[1, 2])
assert reasonings.shape[0] == 2
assert reasonings.shape[-1] == 3
def test_pure_positive_reasonings_init_one_per_class():
r = pt.initializers.PurePositiveReasoningsInitializer(
components_first=False)
@ -234,13 +227,20 @@ def test_pure_positive_reasonings_init_one_per_class():
assert torch.allclose(reasonings[0], torch.eye(4))
def test_pure_positive_reasonings_init_unrepresented_class():
r = pt.initializers.PurePositiveReasoningsInitializer(
components_first=False)
reasonings = r.generate(distribution=[1, 0, 1])
def test_pure_positive_reasonings_init_unrepresented_classes():
r = pt.initializers.PurePositiveReasoningsInitializer()
reasonings = r.generate(distribution=[9, 0, 0, 0])
assert reasonings.shape[0] == 9
assert reasonings.shape[1] == 4
assert reasonings.shape[2] == 2
def test_random_reasonings_init_channels_not_first():
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
reasonings = r.generate(distribution=[0, 0, 0, 1])
assert reasonings.shape[0] == 2
assert reasonings.shape[1] == 2
assert reasonings.shape[2] == 3
assert reasonings.shape[1] == 4
assert reasonings.shape[2] == 1
# Transform initializers