[BUGFIX] Fix reasonings initializer dimension bug

This commit is contained in:
Jensun Ravichandran 2021-06-17 18:10:05 +02:00
parent ae11fedbf3
commit de61bf7bca
4 changed files with 23 additions and 20 deletions

View File

@ -42,7 +42,6 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
nk = (1 - A) * B nk = (1 - A) * B
numerator = (detections @ (pk - nk).T) + nk.sum(1) numerator = (detections @ (pk - nk).T) + nk.sum(1)
probs = numerator / (pk + nk).sum(1) probs = numerator / (pk + nk).sum(1)
# probs = probs.squeeze(0)
return probs return probs

View File

@ -13,6 +13,7 @@ from .initializers import (
AbstractLabelsInitializer, AbstractLabelsInitializer,
AbstractReasoningsInitializer, AbstractReasoningsInitializer,
LabelsInitializer, LabelsInitializer,
PurePositiveReasoningsInitializer,
RandomReasoningsInitializer, RandomReasoningsInitializer,
) )
@ -308,10 +309,13 @@ class ReasoningComponents(AbstractComponents):
three element probability distribution. three element probability distribution.
""" """
def __init__(self, distribution: Union[dict, list, tuple], def __init__(
components_initializer: AbstractComponentsInitializer, self,
reasonings_initializer: AbstractReasoningsInitializer, distribution: Union[dict, list, tuple],
**kwargs): components_initializer: AbstractComponentsInitializer,
reasonings_initializer:
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(),
**kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.add_components(distribution, components_initializer, self.add_components(distribution, components_initializer,
reasonings_initializer) reasonings_initializer)

View File

@ -296,7 +296,7 @@ class OneHotLabelsInitializer(LabelsInitializer):
# Reasonings # Reasonings
class AbstractReasoningsInitializer(ABC): class AbstractReasoningsInitializer(ABC):
"""Abstract class for all reasonings initializers.""" """Abstract class for all reasonings initializers."""
def __init__(self, components_first=True): def __init__(self, components_first: bool = True):
self.components_first = components_first self.components_first = components_first
def compute_shape(self, distribution): def compute_shape(self, distribution):
@ -375,7 +375,7 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
num_components, num_classes, _ = self.compute_shape(distribution) num_components, num_classes, _ = self.compute_shape(distribution)
A = OneHotLabelsInitializer().generate(distribution) A = OneHotLabelsInitializer().generate(distribution)
B = torch.zeros(num_components, num_classes) B = torch.zeros(num_components, num_classes)
reasonings = torch.stack([A, B]).permute(2, 1, 0) reasonings = torch.stack([A, B], dim=-1)
reasonings = self.generate_end_hook(reasonings) reasonings = self.generate_end_hook(reasonings)
return reasonings return reasonings

View File

@ -220,13 +220,6 @@ def test_ones_reasonings_init():
assert torch.allclose(reasonings, torch.zeros(6, 3, 2)) assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
def test_random_reasonings_init_channels_not_first():
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
reasonings = r.generate(distribution=[1, 2])
assert reasonings.shape[0] == 2
assert reasonings.shape[-1] == 3
def test_pure_positive_reasonings_init_one_per_class(): def test_pure_positive_reasonings_init_one_per_class():
r = pt.initializers.PurePositiveReasoningsInitializer( r = pt.initializers.PurePositiveReasoningsInitializer(
components_first=False) components_first=False)
@ -234,13 +227,20 @@ def test_pure_positive_reasonings_init_one_per_class():
assert torch.allclose(reasonings[0], torch.eye(4)) assert torch.allclose(reasonings[0], torch.eye(4))
def test_pure_positive_reasonings_init_unrepresented_class(): def test_pure_positive_reasonings_init_unrepresented_classes():
r = pt.initializers.PurePositiveReasoningsInitializer( r = pt.initializers.PurePositiveReasoningsInitializer()
components_first=False) reasonings = r.generate(distribution=[9, 0, 0, 0])
reasonings = r.generate(distribution=[1, 0, 1]) assert reasonings.shape[0] == 9
assert reasonings.shape[1] == 4
assert reasonings.shape[2] == 2
def test_random_reasonings_init_channels_not_first():
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
reasonings = r.generate(distribution=[0, 0, 0, 1])
assert reasonings.shape[0] == 2 assert reasonings.shape[0] == 2
assert reasonings.shape[1] == 2 assert reasonings.shape[1] == 4
assert reasonings.shape[2] == 3 assert reasonings.shape[2] == 1
# Transform initializers # Transform initializers