diff --git a/prototorch/__init__.py b/prototorch/__init__.py index d549de2..d0ce2d6 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -14,6 +14,7 @@ from .core import ( components, distances, initializers, + similarities, losses, pooling, ) @@ -31,6 +32,7 @@ __all_core__ = [ "losses", "nn", "pooling", + "similarities", "utils", ] diff --git a/prototorch/core/competitions.py b/prototorch/core/competitions.py index 2e354b6..2a54e10 100644 --- a/prototorch/core/competitions.py +++ b/prototorch/core/competitions.py @@ -28,6 +28,24 @@ def knnc(distances: torch.Tensor, return winning_labels +def cbcc(detections: torch.Tensor, reasonings: torch.Tensor): + """Classification-By-Components Competition. + + Returns probability distributions over the classes. + + `detections` must be of shape [batch_size, num_components]. + `reasonings` must be of shape [num_components, num_classes, 2]. + + """ + A, B = reasonings.permute(2, 1, 0).clamp(0, 1) + pk = A + nk = (1 - A) * B + numerator = (detections @ (pk - nk).T) + nk.sum(1) + probs = numerator / (pk + nk).sum(1) + # probs = probs.squeeze(0) + return probs + + class WTAC(torch.nn.Module): """Winner-Takes-All-Competition Layer. @@ -63,3 +81,13 @@ class KNNC(torch.nn.Module): def extra_repr(self): return f"k: {self.k}" + + +class CBCC(torch.nn.Module): + """Classification-By-Components Competition. + + Thin wrapper over the `cbcc` function. + + """ + def forward(self, detections, reasonings): + return cbcc(detections, reasonings) diff --git a/prototorch/core/components.py b/prototorch/core/components.py index f1694ab..d497cdf 100644 --- a/prototorch/core/components.py +++ b/prototorch/core/components.py @@ -13,6 +13,7 @@ from .initializers import ( AbstractLabelsInitializer, AbstractReasoningsInitializer, LabelsInitializer, + RandomReasoningsInitializer, ) @@ -112,7 +113,7 @@ class AbstractLabels(torch.nn.Module): """Abstract class for all labels modules.""" @property def labels(self): - return self._labels + return self._labels.cpu() @property def num_labels(self): @@ -174,6 +175,10 @@ class Labels(AbstractLabels): self._register_labels(_labels) return mask + def forward(self): + """Simply return the labels.""" + return self._labels + class LabeledComponents(AbstractComponents): """A set of adaptable components and corresponding unadaptable labels.""" @@ -188,11 +193,6 @@ class LabeledComponents(AbstractComponents): self.add_components(distribution, components_initializer, labels_initializer) - @property - def labels(self): - """Tensor containing the component labels.""" - return self._labels - @property def distribution(self): unique, counts = torch.unique(self._labels, @@ -200,6 +200,15 @@ class LabeledComponents(AbstractComponents): return_counts=True) return dict(zip(unique.tolist(), counts.tolist())) + @property + def num_classes(self): + return len(self.distribution.keys()) + + @property + def labels(self): + """Tensor containing the component labels.""" + return self._labels.cpu() + def _register_labels(self, labels): self.register_buffer("_labels", labels) @@ -236,6 +245,64 @@ class LabeledComponents(AbstractComponents): return self._components, self._labels +class Reasonings(torch.nn.Module): + """A set of standalone reasoning matrices. + + The `reasonings` tensor is of shape [num_components, num_classes, 2]. + + """ + def __init__(self, + distribution: Union[dict, list, tuple], + initializer: + AbstractReasoningsInitializer = RandomReasoningsInitializer(), + **kwargs): + super().__init__(**kwargs) + + @property + def num_classes(self): + return self._reasonings.shape[1] + + # @property + # def reasonings(self): + # """Tensor containing the reasoning matrices.""" + # return self._reasonings.detach().cpu() + + @property + def reasonings(self): + with torch.no_grad(): + A, B = self._reasonings.permute(2, 1, 0).clamp(0, 1) + pk = A + nk = (1 - pk) * B + ik = 1 - pk - nk + img = torch.cat([pk, nk, ik], dim=0).permute(1, 0, 2) + return img.unsqueeze(1).cpu() + + def _register_reasonings(self, reasonings): + self.register_buffer("_reasonings", reasonings) + + def add_reasonings( + self, + distribution: Union[dict, list, tuple], + initializer: + AbstractReasoningsInitializer = RandomReasoningsInitializer()): + """Generate and add new reasonings.""" + assert validate_initializer(initializer, AbstractReasoningsInitializer) + _reasonings, new_reasonings = gencat(self, "_reasonings", initializer, + distribution) + self._register_reasonings(_reasonings) + return new_reasonings + + def remove_reasonings(self, indices): + """Remove reasonings at specified indices.""" + _reasonings, mask = removeind(self, "_reasonings", indices) + self._register_reasonings(_reasonings) + return mask + + def forward(self): + """Simply return the reasonings.""" + return self._reasonings + + class ReasoningComponents(AbstractComponents): """A set of components and a corresponding adapatable reasoning matrices. @@ -260,13 +327,8 @@ class ReasoningComponents(AbstractComponents): reasonings_initializer) @property - def reasonings(self): - """Returns Reasoning Matrix. - - Dimension NxCx2 - - """ - return self._reasonings.detach().cpu() + def num_classes(self): + return self._reasonings.shape[1] def _register_reasonings(self, reasonings): self.register_parameter("_reasonings", Parameter(reasonings))