From c0c0044a425d0b20109ea8ab87fc01b7a4634a26 Mon Sep 17 00:00:00 2001
From: Jensun Ravichandran
Date: Fri, 11 Jun 2021 14:52:09 +0200
Subject: [PATCH] [REFACTOR] Remove CustomLabelsInitializer

---
 prototorch/components/components.py   |  7 ++++---
 prototorch/components/initializers.py | 19 +++++--------------
 prototorch/modules/competitions.py    |  3 +++
 3 files changed, 12 insertions(+), 17 deletions(-)

diff --git a/prototorch/components/components.py b/prototorch/components/components.py
index 5968da5..6d001f7 100644
--- a/prototorch/components/components.py
+++ b/prototorch/components/components.py
@@ -5,7 +5,6 @@ import warnings
 import torch
 from prototorch.components.initializers import (ClassAwareInitializer,
                                                 ComponentsInitializer,
-                                                CustomLabelsInitializer,
                                                 EqualLabelsInitializer,
                                                 UnequalLabelsInitializer,
                                                 ZeroReasoningsInitializer)
@@ -21,7 +20,9 @@ def get_labels_object(distribution):
                 distribution["num_classes"],
                 distribution["prototypes_per_class"])
         else:
-            labels = CustomLabelsInitializer(distribution)
+            clabels = list(distribution.keys())
+            dist = list(distribution.values())
+            labels = UnequalLabelsInitializer(dist, clabels)
     elif isinstance(distribution, tuple):
         num_classes, prototypes_per_class = distribution
         labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
@@ -156,7 +157,7 @@ class LabeledComponents(Components):

         # Components
         if isinstance(initializer, ClassAwareInitializer):
-            _new = initializer.generate(len(new_labels), labels.distribution)
+            _new = initializer.generate(len(new_labels), distribution)
         else:
             _new = initializer.generate(len(new_labels))
         _components = torch.cat([self._components, _new])
diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py
index 4582d6d..d05c6c7 100644
--- a/prototorch/components/initializers.py
+++ b/prototorch/components/initializers.py
@@ -174,19 +174,17 @@ class LabelsInitializer:


 class UnequalLabelsInitializer(LabelsInitializer):
-    def __init__(self, dist):
+    def __init__(self, dist, clabels=None):
         self.dist = dist
+        self.clabels = clabels or range(len(self.dist))

     @property
     def distribution(self):
         return self.dist

-    def generate(self, clabels=None, dist=None):
-        if not clabels:
-            clabels = range(len(self.dist))
-        if not dist:
-            dist = self.dist
-        targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
+    def generate(self):
+        targets = list(
+            chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
         return torch.LongTensor(targets)


@@ -203,13 +201,6 @@ class EqualLabelsInitializer(LabelsInitializer):
         return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()


-class CustomLabelsInitializer(UnequalLabelsInitializer):
-    def generate(self):
-        clabels = list(self.dist.keys())
-        dist = list(self.dist.values())
-        return super().generate(clabels, dist)
-
-
 # Reasonings
 class ReasoningsInitializer:
     def generate(self, length):
diff --git a/prototorch/modules/competitions.py b/prototorch/modules/competitions.py
index a15631a..585c5d6 100644
--- a/prototorch/modules/competitions.py
+++ b/prototorch/modules/competitions.py
@@ -10,6 +10,7 @@ class WTAC(torch.nn.Module):

     Thin wrapper over the `wtac` function.
     """
+
     def forward(self, distances, labels):
         return wtac(distances, labels)

@@ -20,6 +21,7 @@ class LTAC(torch.nn.Module):

     Thin wrapper over the `wtac` function.
     """
+
     def forward(self, probs, labels):
         return wtac(-1.0 * probs, labels)

@@ -30,6 +32,7 @@ class KNNC(torch.nn.Module):

     Thin wrapper over the `knnc` function.
     """
+
     def __init__(self, k=1, **kwargs):
         super().__init__(**kwargs)
         self.k = k