From de61bf7bca92126f44b19b823c399bb8da368186 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Thu, 17 Jun 2021 18:10:05 +0200 Subject: [PATCH] [BUGFIX] Fix reasonings initializer dimension bug --- prototorch/core/competitions.py | 1 - prototorch/core/components.py | 12 ++++++++---- prototorch/core/initializers.py | 4 ++-- tests/test_core.py | 26 +++++++++++++------------- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/prototorch/core/competitions.py b/prototorch/core/competitions.py index 2a54e10..3e57005 100644 --- a/prototorch/core/competitions.py +++ b/prototorch/core/competitions.py @@ -42,7 +42,6 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor): nk = (1 - A) * B numerator = (detections @ (pk - nk).T) + nk.sum(1) probs = numerator / (pk + nk).sum(1) - # probs = probs.squeeze(0) return probs diff --git a/prototorch/core/components.py b/prototorch/core/components.py index c9edcbb..330e89a 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -13,6 +13,7 @@ from .initializers import ( AbstractLabelsInitializer, AbstractReasoningsInitializer, LabelsInitializer, + PurePositiveReasoningsInitializer, RandomReasoningsInitializer, ) @@ -308,10 +309,13 @@ class ReasoningComponents(AbstractComponents): three element probability distribution. """ - def __init__(self, distribution: Union[dict, list, tuple], - components_initializer: AbstractComponentsInitializer, - reasonings_initializer: AbstractReasoningsInitializer, - **kwargs): + def __init__( + self, + distribution: Union[dict, list, tuple], + components_initializer: AbstractComponentsInitializer, + reasonings_initializer: + AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(), + **kwargs): super().__init__(**kwargs) self.add_components(distribution, components_initializer, reasonings_initializer) diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index f5d2743..7041cbb 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -296,7 +296,7 @@ class OneHotLabelsInitializer(LabelsInitializer): # Reasonings class AbstractReasoningsInitializer(ABC): """Abstract class for all reasonings initializers.""" - def __init__(self, components_first=True): + def __init__(self, components_first: bool = True): self.components_first = components_first def compute_shape(self, distribution): @@ -375,7 +375,7 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): num_components, num_classes, _ = self.compute_shape(distribution) A = OneHotLabelsInitializer().generate(distribution) B = torch.zeros(num_components, num_classes) - reasonings = torch.stack([A, B]).permute(2, 1, 0) + reasonings = torch.stack([A, B], dim=-1) reasonings = self.generate_end_hook(reasonings) return reasonings diff --git a/tests/test_core.py b/tests/test_core.py index f949037..d007f9b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -220,13 +220,6 @@ def test_ones_reasonings_init(): assert torch.allclose(reasonings, torch.zeros(6, 3, 2)) -def test_random_reasonings_init_channels_not_first(): - r = pt.initializers.RandomReasoningsInitializer(components_first=False) - reasonings = r.generate(distribution=[1, 2]) - assert reasonings.shape[0] == 2 - assert reasonings.shape[-1] == 3 - - def test_pure_positive_reasonings_init_one_per_class(): r = pt.initializers.PurePositiveReasoningsInitializer( components_first=False) @@ -234,13 +227,20 @@ def test_pure_positive_reasonings_init_one_per_class(): assert torch.allclose(reasonings[0], torch.eye(4)) -def test_pure_positive_reasonings_init_unrepresented_class(): - r = pt.initializers.PurePositiveReasoningsInitializer( - components_first=False) - reasonings = r.generate(distribution=[1, 0, 1]) +def test_pure_positive_reasonings_init_unrepresented_classes(): + r = pt.initializers.PurePositiveReasoningsInitializer() + reasonings = r.generate(distribution=[9, 0, 0, 0]) + assert reasonings.shape[0] == 9 + assert reasonings.shape[1] == 4 + assert reasonings.shape[2] == 2 + + +def test_random_reasonings_init_channels_not_first(): + r = pt.initializers.RandomReasoningsInitializer(components_first=False) + reasonings = r.generate(distribution=[0, 0, 0, 1]) assert reasonings.shape[0] == 2 - assert reasonings.shape[1] == 2 - assert reasonings.shape[2] == 3 + assert reasonings.shape[1] == 4 + assert reasonings.shape[2] == 1 # Transform initializers