[BUGFIX] Fix reasonings initializer dimension bug
This commit is contained in:
parent
ae11fedbf3
commit
de61bf7bca
@ -42,7 +42,6 @@ def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
|
|||||||
nk = (1 - A) * B
|
nk = (1 - A) * B
|
||||||
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
numerator = (detections @ (pk - nk).T) + nk.sum(1)
|
||||||
probs = numerator / (pk + nk).sum(1)
|
probs = numerator / (pk + nk).sum(1)
|
||||||
# probs = probs.squeeze(0)
|
|
||||||
return probs
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from .initializers import (
|
|||||||
AbstractLabelsInitializer,
|
AbstractLabelsInitializer,
|
||||||
AbstractReasoningsInitializer,
|
AbstractReasoningsInitializer,
|
||||||
LabelsInitializer,
|
LabelsInitializer,
|
||||||
|
PurePositiveReasoningsInitializer,
|
||||||
RandomReasoningsInitializer,
|
RandomReasoningsInitializer,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -308,9 +309,12 @@ class ReasoningComponents(AbstractComponents):
|
|||||||
three element probability distribution.
|
three element probability distribution.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, distribution: Union[dict, list, tuple],
|
def __init__(
|
||||||
|
self,
|
||||||
|
distribution: Union[dict, list, tuple],
|
||||||
components_initializer: AbstractComponentsInitializer,
|
components_initializer: AbstractComponentsInitializer,
|
||||||
reasonings_initializer: AbstractReasoningsInitializer,
|
reasonings_initializer:
|
||||||
|
AbstractReasoningsInitializer = PurePositiveReasoningsInitializer(),
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.add_components(distribution, components_initializer,
|
self.add_components(distribution, components_initializer,
|
||||||
|
@ -296,7 +296,7 @@ class OneHotLabelsInitializer(LabelsInitializer):
|
|||||||
# Reasonings
|
# Reasonings
|
||||||
class AbstractReasoningsInitializer(ABC):
|
class AbstractReasoningsInitializer(ABC):
|
||||||
"""Abstract class for all reasonings initializers."""
|
"""Abstract class for all reasonings initializers."""
|
||||||
def __init__(self, components_first=True):
|
def __init__(self, components_first: bool = True):
|
||||||
self.components_first = components_first
|
self.components_first = components_first
|
||||||
|
|
||||||
def compute_shape(self, distribution):
|
def compute_shape(self, distribution):
|
||||||
@ -375,7 +375,7 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
num_components, num_classes, _ = self.compute_shape(distribution)
|
num_components, num_classes, _ = self.compute_shape(distribution)
|
||||||
A = OneHotLabelsInitializer().generate(distribution)
|
A = OneHotLabelsInitializer().generate(distribution)
|
||||||
B = torch.zeros(num_components, num_classes)
|
B = torch.zeros(num_components, num_classes)
|
||||||
reasonings = torch.stack([A, B]).permute(2, 1, 0)
|
reasonings = torch.stack([A, B], dim=-1)
|
||||||
reasonings = self.generate_end_hook(reasonings)
|
reasonings = self.generate_end_hook(reasonings)
|
||||||
return reasonings
|
return reasonings
|
||||||
|
|
||||||
|
@ -220,13 +220,6 @@ def test_ones_reasonings_init():
|
|||||||
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
|
assert torch.allclose(reasonings, torch.zeros(6, 3, 2))
|
||||||
|
|
||||||
|
|
||||||
def test_random_reasonings_init_channels_not_first():
|
|
||||||
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
|
|
||||||
reasonings = r.generate(distribution=[1, 2])
|
|
||||||
assert reasonings.shape[0] == 2
|
|
||||||
assert reasonings.shape[-1] == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_pure_positive_reasonings_init_one_per_class():
|
def test_pure_positive_reasonings_init_one_per_class():
|
||||||
r = pt.initializers.PurePositiveReasoningsInitializer(
|
r = pt.initializers.PurePositiveReasoningsInitializer(
|
||||||
components_first=False)
|
components_first=False)
|
||||||
@ -234,13 +227,20 @@ def test_pure_positive_reasonings_init_one_per_class():
|
|||||||
assert torch.allclose(reasonings[0], torch.eye(4))
|
assert torch.allclose(reasonings[0], torch.eye(4))
|
||||||
|
|
||||||
|
|
||||||
def test_pure_positive_reasonings_init_unrepresented_class():
|
def test_pure_positive_reasonings_init_unrepresented_classes():
|
||||||
r = pt.initializers.PurePositiveReasoningsInitializer(
|
r = pt.initializers.PurePositiveReasoningsInitializer()
|
||||||
components_first=False)
|
reasonings = r.generate(distribution=[9, 0, 0, 0])
|
||||||
reasonings = r.generate(distribution=[1, 0, 1])
|
assert reasonings.shape[0] == 9
|
||||||
|
assert reasonings.shape[1] == 4
|
||||||
|
assert reasonings.shape[2] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_random_reasonings_init_channels_not_first():
|
||||||
|
r = pt.initializers.RandomReasoningsInitializer(components_first=False)
|
||||||
|
reasonings = r.generate(distribution=[0, 0, 0, 1])
|
||||||
assert reasonings.shape[0] == 2
|
assert reasonings.shape[0] == 2
|
||||||
assert reasonings.shape[1] == 2
|
assert reasonings.shape[1] == 4
|
||||||
assert reasonings.shape[2] == 3
|
assert reasonings.shape[2] == 1
|
||||||
|
|
||||||
|
|
||||||
# Transform initializers
|
# Transform initializers
|
||||||
|
Loading…
Reference in New Issue
Block a user