From 3d76dffe3cb5c81820827fb36ca1126dd6cb07ae Mon Sep 17 00:00:00 2001
From: Alexander Engelsberger
Date: Mon, 21 Jun 2021 14:29:25 +0200
Subject: [PATCH] chore: Allow no-self-use for some class members

Classes are used as common interface and connection to pytorch.
---
 prototorch/core/competitions.py |  6 +++---
 prototorch/core/initializers.py | 22 ++++++++++++----------
 prototorch/core/pooling.py      |  8 ++++----
 3 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/prototorch/core/competitions.py b/prototorch/core/competitions.py
index 8ac72e7..520561b 100644
--- a/prototorch/core/competitions.py
+++ b/prototorch/core/competitions.py
@@ -48,7 +48,7 @@ class WTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.
 
     """
-    def forward(self, distances, labels):
+    def forward(self, distances, labels):  # pylint: disable=no-self-use
         return wtac(distances, labels)
 
 
@@ -58,7 +58,7 @@ class LTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.
 
     """
-    def forward(self, probs, labels):
+    def forward(self, probs, labels):  # pylint: disable=no-self-use
         return wtac(-1.0 * probs, labels)
 
 
@@ -85,5 +85,5 @@ class CBCC(torch.nn.Module):
     Thin wrapper over the `cbcc` function.
 
     """
-    def forward(self, detections, reasonings):
+    def forward(self, detections, reasonings):  # pylint: disable=no-self-use
         return cbcc(detections, reasonings)
diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py
index fa4299c..4518bc7 100644
--- a/prototorch/core/initializers.py
+++ b/prototorch/core/initializers.py
@@ -303,17 +303,18 @@ class OneHotLabelsInitializer(LabelsInitializer):
 # Reasonings
 
 
+def compute_distribution_shape(distribution):
+    distribution = parse_distribution(distribution)
+    num_components = sum(distribution.values())
+    num_classes = len(distribution.keys())
+    return (num_components, num_classes, 2)
+
+
 class AbstractReasoningsInitializer(ABC):
     """Abstract class for all reasonings initializers."""
     def __init__(self, components_first: bool = True):
         self.components_first = components_first
 
-    def compute_shape(self, distribution):
-        distribution = parse_distribution(distribution)
-        num_components = sum(distribution.values())
-        num_classes = len(distribution.keys())
-        return (num_components, num_classes, 2)
-
     def generate_end_hook(self, reasonings):
         if not self.components_first:
             reasonings = reasonings.permute(2, 1, 0)
@@ -349,7 +350,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
 class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
     """Reasonings are all initialized with zeros."""
     def generate(self, distribution: Union[dict, list, tuple]):
-        shape = self.compute_shape(distribution)
+        shape = compute_distribution_shape(distribution)
         reasonings = torch.zeros(*shape)
         reasonings = self.generate_end_hook(reasonings)
         return reasonings
@@ -358,7 +359,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
 class OnesReasoningsInitializer(AbstractReasoningsInitializer):
     """Reasonings are all initialized with ones."""
     def generate(self, distribution: Union[dict, list, tuple]):
-        shape = self.compute_shape(distribution)
+        shape = compute_distribution_shape(distribution)
         reasonings = torch.ones(*shape)
         reasonings = self.generate_end_hook(reasonings)
         return reasonings
@@ -372,7 +373,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
         self.maximum = maximum
 
     def generate(self, distribution: Union[dict, list, tuple]):
-        shape = self.compute_shape(distribution)
+        shape = compute_distribution_shape(distribution)
         reasonings = torch.ones(*shape).uniform_(self.minimum, self.maximum)
         reasonings = self.generate_end_hook(reasonings)
         return reasonings
@@ -381,7 +382,8 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
 class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
     """Each component reasons positively for exactly one class."""
     def generate(self, distribution: Union[dict, list, tuple]):
-        num_components, num_classes, _ = self.compute_shape(distribution)
+        num_components, num_classes, _ = compute_distribution_shape(
+            distribution)
         A = OneHotLabelsInitializer().generate(distribution)
         B = torch.zeros(num_components, num_classes)
         reasonings = torch.stack([A, B], dim=-1)
diff --git a/prototorch/core/pooling.py b/prototorch/core/pooling.py
index fab143f..3ccf3a6 100644
--- a/prototorch/core/pooling.py
+++ b/prototorch/core/pooling.py
@@ -82,23 +82,23 @@ def stratified_prod_pooling(values: torch.Tensor,
 
 class StratifiedSumPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_sum_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_sum_pooling(values, labels)
 
 
 class StratifiedProdPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_prod_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_prod_pooling(values, labels)
 
 
 class StratifiedMinPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_min_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_min_pooling(values, labels)
 
 
 class StratifiedMaxPooling(torch.nn.Module):
     """Thin wrapper over the `stratified_max_pooling` function."""
-    def forward(self, values, labels):
+    def forward(self, values, labels):  # pylint: disable=no-self-use
         return stratified_max_pooling(values, labels)
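
For illustration, here is a minimal sketch of what the extracted compute_distribution_shape helper returns, assuming parse_distribution has already normalized its argument into a dict mapping class labels to per-class component counts; the literal dict below only stands in for that normalized form and is not part of the patch:

    import torch

    # Illustrative, already-parsed distribution: three classes with
    # 2, 2 and 3 components respectively.
    distribution = {0: 2, 1: 2, 2: 3}

    # Mirrors the body of compute_distribution_shape(distribution):
    num_components = sum(distribution.values())  # 7
    num_classes = len(distribution.keys())       # 3
    shape = (num_components, num_classes, 2)     # (7, 3, 2)

    # An initializer such as OnesReasoningsInitializer then fills this shape:
    reasonings = torch.ones(*shape)
    print(reasonings.shape)  # torch.Size([7, 3, 2])

The trailing dimension of size 2 holds the two reasoning blocks per (component, class) pair; PurePositiveReasoningsInitializer, for example, fills it by stacking a one-hot block A and a zero block B along dim=-1, as in the hunk above.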