[BUGFIX] Fix reasonings initializer dimension bug
This commit is contained in:
parent
ae11fedbf3
commit
de61bf7bca
@ -42,7 +42,6 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
|
||||
nk = (1 - A) * B
|
||||
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
||||
probs = numerator / (pk + nk).sum(1)
|
||||
# probs = probs.squeeze(0)
|
||||
return probs
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@ from .initializers import (
|
||||
AbstractLabelsInitializer,
|
||||
AbstractReasoningsInitializer,
|
||||
LabelsInitializer,
|
||||
PurePositiveReasoningsInitializer,
|
||||
RandomReasoningsInitializer,
|
||||
)
|
||||
|
||||
@ -308,10 +309,13 @@ class ReasoningComponents(AbstractComponents):
|
||||
three element probability distribution.
|
||||
|
||||
"""
|
||||
def __init__(self, distribution: Union[dict, list, tuple],
|
||||
components_initializer: AbstractComponentsInitializer,
|
||||
reasonings_initializer: AbstractReasoningsInitializer,
|
||||
**kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
distribution: Union[dict, list, tuple],
|
||||
components_initializer: AbstractComponentsInitializer,
|
||||
reasonings_initializer:
|
||||
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(),
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.add_components(distribution, components_initializer,
|
||||
reasonings_initializer)
|
||||
|
@ -296,7 +296,7 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
||||
# Reasonings
|
||||
class AbstractReasoningsInitializer(ABC):
|
||||
"""Abstract class for all reasonings initializers."""
|
||||
def __init__(self, components_first=True):
|
||||
def __init__(self, components_first: bool = True):
|
||||
self.components_first = components_first
|
||||
|
||||
def compute_shape(self, distribution):
|
||||
@ -375,7 +375,7 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
||||
A = OneHotLabelsInitializer().generate(distribution)
|
||||
B = torch.zeros(num_components, num_classes)
|
||||
reasonings = torch.stack([A, B]).permute(2, 1, 0)
|
||||
reasonings = torch.stack([A, B], dim=-1)
|
||||
reasonings = self.generate_end_hook(reasonings)
|
||||
return reasonings
|
||||
|
||||
|
@ -220,13 +220,6 @@ def test_ones_reasonings_init():
|
||||
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
|
||||
|
||||
|
||||
def test_random_reasonings_init_channels_not_first():
|
||||
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
|
||||
reasonings = r.generate(distribution=[1, 2])
|
||||
assert reasonings.shape[0] == 2
|
||||
assert reasonings.shape[-1] == 3
|
||||
|
||||
|
||||
def test_pure_positive_reasonings_init_one_per_class():
|
||||
r = pt.initializers.PurePositiveReasoningsInitializer(
|
||||
components_first=False)
|
||||
@ -234,13 +227,20 @@ def test_pure_positive_reasonings_init_one_per_class():
|
||||
assert torch.allclose(reasonings[0], torch.eye(4))
|
||||
|
||||
|
||||
def test_pure_positive_reasonings_init_unrepresented_class():
|
||||
r = pt.initializers.PurePositiveReasoningsInitializer(
|
||||
components_first=False)
|
||||
reasonings = r.generate(distribution=[1, 0, 1])
|
||||
def test_pure_positive_reasonings_init_unrepresented_classes():
|
||||
r = pt.initializers.PurePositiveReasoningsInitializer()
|
||||
reasonings = r.generate(distribution=[9, 0, 0, 0])
|
||||
assert reasonings.shape[0] == 9
|
||||
assert reasonings.shape[1] == 4
|
||||
assert reasonings.shape[2] == 2
|
||||
|
||||
|
||||
def test_random_reasonings_init_channels_not_first():
|
||||
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
|
||||
reasonings = r.generate(distribution=[0, 0, 0, 1])
|
||||
assert reasonings.shape[0] == 2
|
||||
assert reasonings.shape[1] == 2
|
||||
assert reasonings.shape[2] == 3
|
||||
assert reasonings.shape[1] == 4
|
||||
assert reasonings.shape[2] == 1
|
||||
|
||||
|
||||
# Transform initializers
|
||||
|
Loading…
Reference in New Issue
Block a user