[FEATURE] Add standalone reasonings and CBC competition
This commit is contained in:
		@@ -14,6 +14,7 @@ from .core import (
 | 
				
			|||||||
    components,
 | 
					    components,
 | 
				
			||||||
    distances,
 | 
					    distances,
 | 
				
			||||||
    initializers,
 | 
					    initializers,
 | 
				
			||||||
 | 
					    similarities,
 | 
				
			||||||
    losses,
 | 
					    losses,
 | 
				
			||||||
    pooling,
 | 
					    pooling,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -31,6 +32,7 @@ __all_core__ = [
 | 
				
			|||||||
    "losses",
 | 
					    "losses",
 | 
				
			||||||
    "nn",
 | 
					    "nn",
 | 
				
			||||||
    "pooling",
 | 
					    "pooling",
 | 
				
			||||||
 | 
					    "similarities",
 | 
				
			||||||
    "utils",
 | 
					    "utils",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,6 +28,24 @@ def knnc(distances: torch.Tensor,
 | 
				
			|||||||
    return winning_labels
 | 
					    return winning_labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def cbcc(detections: torch.Tensor, reasonings: torch.Tensor):
 | 
				
			||||||
 | 
					    """Classification-By-Components Competition.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Returns probability distributions over the classes.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    `detections` must be of shape [batch_size, num_components].
 | 
				
			||||||
 | 
					    `reasonings` must be of shape [num_components, num_classes, 2].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    A, B = reasonings.permute(2, 1, 0).clamp(0, 1)
 | 
				
			||||||
 | 
					    pk = A
 | 
				
			||||||
 | 
					    nk = (1 - A) * B
 | 
				
			||||||
 | 
					    numerator = (detections @ (pk - nk).T) + nk.sum(1)
 | 
				
			||||||
 | 
					    probs = numerator / (pk + nk).sum(1)
 | 
				
			||||||
 | 
					    # probs = probs.squeeze(0)
 | 
				
			||||||
 | 
					    return probs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class WTAC(torch.nn.Module):
 | 
					class WTAC(torch.nn.Module):
 | 
				
			||||||
    """Winner-Takes-All-Competition Layer.
 | 
					    """Winner-Takes-All-Competition Layer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -63,3 +81,13 @@ class KNNC(torch.nn.Module):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def extra_repr(self):
 | 
					    def extra_repr(self):
 | 
				
			||||||
        return f"k: {self.k}"
 | 
					        return f"k: {self.k}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class CBCC(torch.nn.Module):
 | 
				
			||||||
 | 
					    """Classification-By-Components Competition.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Thin wrapper over the `cbcc` function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def forward(self, detections, reasonings):
 | 
				
			||||||
 | 
					        return cbcc(detections, reasonings)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,6 +13,7 @@ from .initializers import (
 | 
				
			|||||||
    AbstractLabelsInitializer,
 | 
					    AbstractLabelsInitializer,
 | 
				
			||||||
    AbstractReasoningsInitializer,
 | 
					    AbstractReasoningsInitializer,
 | 
				
			||||||
    LabelsInitializer,
 | 
					    LabelsInitializer,
 | 
				
			||||||
 | 
					    RandomReasoningsInitializer,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -112,7 +113,7 @@ class AbstractLabels(torch.nn.Module):
 | 
				
			|||||||
    """Abstract class for all labels modules."""
 | 
					    """Abstract class for all labels modules."""
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def labels(self):
 | 
					    def labels(self):
 | 
				
			||||||
        return self._labels
 | 
					        return self._labels.cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def num_labels(self):
 | 
					    def num_labels(self):
 | 
				
			||||||
@@ -174,6 +175,10 @@ class Labels(AbstractLabels):
 | 
				
			|||||||
        self._register_labels(_labels)
 | 
					        self._register_labels(_labels)
 | 
				
			||||||
        return mask
 | 
					        return mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self):
 | 
				
			||||||
 | 
					        """Simply return the labels."""
 | 
				
			||||||
 | 
					        return self._labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LabeledComponents(AbstractComponents):
 | 
					class LabeledComponents(AbstractComponents):
 | 
				
			||||||
    """A set of adaptable components and corresponding unadaptable labels."""
 | 
					    """A set of adaptable components and corresponding unadaptable labels."""
 | 
				
			||||||
@@ -188,11 +193,6 @@ class LabeledComponents(AbstractComponents):
 | 
				
			|||||||
        self.add_components(distribution, components_initializer,
 | 
					        self.add_components(distribution, components_initializer,
 | 
				
			||||||
                            labels_initializer)
 | 
					                            labels_initializer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					 | 
				
			||||||
    def labels(self):
 | 
					 | 
				
			||||||
        """Tensor containing the component labels."""
 | 
					 | 
				
			||||||
        return self._labels
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def distribution(self):
 | 
					    def distribution(self):
 | 
				
			||||||
        unique, counts = torch.unique(self._labels,
 | 
					        unique, counts = torch.unique(self._labels,
 | 
				
			||||||
@@ -200,6 +200,15 @@ class LabeledComponents(AbstractComponents):
 | 
				
			|||||||
                                      return_counts=True)
 | 
					                                      return_counts=True)
 | 
				
			||||||
        return dict(zip(unique.tolist(), counts.tolist()))
 | 
					        return dict(zip(unique.tolist(), counts.tolist()))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def num_classes(self):
 | 
				
			||||||
 | 
					        return len(self.distribution.keys())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def labels(self):
 | 
				
			||||||
 | 
					        """Tensor containing the component labels."""
 | 
				
			||||||
 | 
					        return self._labels.cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _register_labels(self, labels):
 | 
					    def _register_labels(self, labels):
 | 
				
			||||||
        self.register_buffer("_labels", labels)
 | 
					        self.register_buffer("_labels", labels)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -236,6 +245,64 @@ class LabeledComponents(AbstractComponents):
 | 
				
			|||||||
        return self._components, self._labels
 | 
					        return self._components, self._labels
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Reasonings(torch.nn.Module):
 | 
				
			||||||
 | 
					    """A set of standalone reasoning matrices.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    The `reasonings` tensor is of shape [num_components, num_classes, 2].
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    def __init__(self,
 | 
				
			||||||
 | 
					                 distribution: Union[dict, list, tuple],
 | 
				
			||||||
 | 
					                 initializer:
 | 
				
			||||||
 | 
					                 AbstractReasoningsInitializer = RandomReasoningsInitializer(),
 | 
				
			||||||
 | 
					                 **kwargs):
 | 
				
			||||||
 | 
					        super().__init__(**kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def num_classes(self):
 | 
				
			||||||
 | 
					        return self._reasonings.shape[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # @property
 | 
				
			||||||
 | 
					    # def reasonings(self):
 | 
				
			||||||
 | 
					    #     """Tensor containing the reasoning matrices."""
 | 
				
			||||||
 | 
					    #     return self._reasonings.detach().cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @property
 | 
				
			||||||
 | 
					    def reasonings(self):
 | 
				
			||||||
 | 
					        with torch.no_grad():
 | 
				
			||||||
 | 
					            A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1)
 | 
				
			||||||
 | 
					            pk = A
 | 
				
			||||||
 | 
					            nk = (1 - pk) * B
 | 
				
			||||||
 | 
					            ik = 1 - pk - nk
 | 
				
			||||||
 | 
					            img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2)
 | 
				
			||||||
 | 
					        return img.unsqueeze(1).cpu()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _register_reasonings(self, reasonings):
 | 
				
			||||||
 | 
					        self.register_buffer("_reasonings", reasonings)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_reasonings(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        distribution: Union[dict, list, tuple],
 | 
				
			||||||
 | 
					        initializer:
 | 
				
			||||||
 | 
					        AbstractReasoningsInitializer = RandomReasoningsInitializer()):
 | 
				
			||||||
 | 
					        """Generate and add new reasonings."""
 | 
				
			||||||
 | 
					        assert validate_initializer(initializer, AbstractReasoningsInitializer)
 | 
				
			||||||
 | 
					        _reasonings, new_reasonings = gencat(self, "_reasonings", initializer,
 | 
				
			||||||
 | 
					                                             distribution)
 | 
				
			||||||
 | 
					        self._register_reasonings(_reasonings)
 | 
				
			||||||
 | 
					        return new_reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def remove_reasonings(self, indices):
 | 
				
			||||||
 | 
					        """Remove reasonings at specified indices."""
 | 
				
			||||||
 | 
					        _reasonings, mask = removeind(self, "_reasonings", indices)
 | 
				
			||||||
 | 
					        self._register_reasonings(_reasonings)
 | 
				
			||||||
 | 
					        return mask
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def forward(self):
 | 
				
			||||||
 | 
					        """Simply return the reasonings."""
 | 
				
			||||||
 | 
					        return self._reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ReasoningComponents(AbstractComponents):
 | 
					class ReasoningComponents(AbstractComponents):
 | 
				
			||||||
    """A set of components and a corresponding adapatable reasoning matrices.
 | 
					    """A set of components and a corresponding adapatable reasoning matrices.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -260,13 +327,8 @@ class ReasoningComponents(AbstractComponents):
 | 
				
			|||||||
                            reasonings_initializer)
 | 
					                            reasonings_initializer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def reasonings(self):
 | 
					    def num_classes(self):
 | 
				
			||||||
        """Returns Reasoning Matrix.
 | 
					        return self._reasonings.shape[1]
 | 
				
			||||||
 | 
					 | 
				
			||||||
        Dimension NxCx2
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        """
 | 
					 | 
				
			||||||
        return self._reasonings.detach().cpu()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _register_reasonings(self, reasonings):
 | 
					    def _register_reasonings(self, reasonings):
 | 
				
			||||||
        self.register_parameter("_reasonings", Parameter(reasonings))
 | 
					        self.register_parameter("_reasonings", Parameter(reasonings))
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user