[FEATURE] Add standalone reasonings and CBC competition

This commit is contained in:
Jensun Ravichandran 2021-06-15 15:41:28 +02:00
parent 0f450ed8a0
commit 6e8a52e371
3 changed files with 105 additions and 13 deletions

View File

@ -14,6 +14,7 @@ from .core import (
components,
distances,
initializers,
similarities,
losses,
pooling,
)
@ -31,6 +32,7 @@ __all_core__ = [
"losses",
"nn",
"pooling",
"similarities",
"utils",
]

View File

@ -28,6 +28,24 @@ def knnc(distances: torch.Tensor,
return winning_labels
def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
"""Classification-By-Components Competition.
Returns probability distributions over the classes.
`detections` must be of shape [batch_size, num_components].
`reasonings` must be of shape [num_components, num_classes, 2].
"""
A, B = reasonings.permute(2, 1, 0).clamp(0, 1)
pk = A
nk = (1 - A) * B
numerator = (detections @ (pk - nk).T) + nk.sum(1)
probs = numerator / (pk + nk).sum(1)
# probs = probs.squeeze(0)
return probs
class WTAC(torch.nn.Module):
"""Winner-Takes-All-Competition Layer.
@ -63,3 +81,13 @@ class KNNC(torch.nn.Module):
def extra_repr(self):
return f"k: {self.k}"
class CBCC(torch.nn.Module):
"""Classification-By-Components Competition.
Thin wrapper over the `cbcc` function.
"""
def forward(self, detections, reasonings):
return cbcc(detections, reasonings)

View File

@ -13,6 +13,7 @@ from .initializers import (
AbstractLabelsInitializer,
AbstractReasoningsInitializer,
LabelsInitializer,
RandomReasoningsInitializer,
)
@ -112,7 +113,7 @@ class AbstractLabels(torch.nn.Module):
"""Abstract class for all labels modules."""
@property
def labels(self):
return self._labels
return self._labels.cpu()
@property
def num_labels(self):
@ -174,6 +175,10 @@ class Labels(AbstractLabels):
self._register_labels(_labels)
return mask
def forward(self):
"""Simply return the labels."""
return self._labels
class LabeledComponents(AbstractComponents):
"""A set of adaptable components and corresponding unadaptable labels."""
@ -188,11 +193,6 @@ class LabeledComponents(AbstractComponents):
self.add_components(distribution, components_initializer,
labels_initializer)
@property
def labels(self):
"""Tensor containing the component labels."""
return self._labels
@property
def distribution(self):
unique, counts = torch.unique(self._labels,
@ -200,6 +200,15 @@ class LabeledComponents(AbstractComponents):
return_counts=True)
return dict(zip(unique.tolist(), counts.tolist()))
@property
def num_classes(self):
return len(self.distribution.keys())
@property
def labels(self):
"""Tensor containing the component labels."""
return self._labels.cpu()
def _register_labels(self, labels):
self.register_buffer("_labels", labels)
@ -236,6 +245,64 @@ class LabeledComponents(AbstractComponents):
return self._components, self._labels
class Reasonings(torch.nn.Module):
"""A set of standalone reasoning matrices.
The `reasonings` tensor is of shape [num_components, num_classes, 2].
"""
def __init__(self,
distribution: Union[dict, list, tuple],
initializer:
AbstractReasoningsInitializer = RandomReasoningsInitializer(),
**kwargs):
super().__init__(**kwargs)
@property
def num_classes(self):
return self._reasonings.shape[1]
# @property
# def reasonings(self):
# """Tensor containing the reasoning matrices."""
# return self._reasonings.detach().cpu()
@property
def reasonings(self):
with torch.no_grad():
A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
pk = A
nk = (1 - pk) * B
ik = 1 - pk - nk
img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
return img.unsqueeze(1).cpu()
def _register_reasonings(self, reasonings):
self.register_buffer("_reasonings", reasonings)
def add_reasonings(
self,
distribution: Union[dict, list, tuple],
initializer:
AbstractReasoningsInitializer = RandomReasoningsInitializer()):
"""Generate and add new reasonings."""
assert validate_initializer(initializer, AbstractReasoningsInitializer)
_reasonings, new_reasonings = gencat(self, "_reasonings", initializer,
distribution)
self._register_reasonings(_reasonings)
return new_reasonings
def remove_reasonings(self, indices):
"""Remove reasonings at specified indices."""
_reasonings, mask = removeind(self, "_reasonings", indices)
self._register_reasonings(_reasonings)
return mask
def forward(self):
"""Simply return the reasonings."""
return self._reasonings
class ReasoningComponents(AbstractComponents):
"""A set of components and a corresponding adapatable reasoning matrices.
@ -260,13 +327,8 @@ class ReasoningComponents(AbstractComponents):
reasonings_initializer)
@property
def reasonings(self):
"""Returns Reasoning Matrix.
Dimension NxCx2
"""
return self._reasonings.detach().cpu()
def num_classes(self):
return self._reasonings.shape[1]
def _register_reasonings(self, reasonings):
self.register_parameter("_reasonings", Parameter(reasonings))