[REFACTOR] Remove CustomLabelsInitializer

This commit is contained in:
Jensun Ravichandran 2021-06-11 14:52:09 +02:00
parent 47d7f5831f
commit c0c0044a42
3 changed files with 12 additions and 17 deletions

View File

@ -5,7 +5,6 @@ import warnings
import torch import torch
from prototorch.components.initializers import (ClassAwareInitializer, from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer, ComponentsInitializer,
CustomLabelsInitializer,
EqualLabelsInitializer, EqualLabelsInitializer,
UnequalLabelsInitializer, UnequalLabelsInitializer,
ZeroReasoningsInitializer) ZeroReasoningsInitializer)
@ -21,7 +20,9 @@ def get_labels_object(distribution):
distribution["num_classes"], distribution["num_classes"],
distribution["prototypes_per_class"]) distribution["prototypes_per_class"])
else: else:
labels = CustomLabelsInitializer(distribution) clabels = list(distribution.keys())
dist = list(distribution.values())
labels = UnequalLabelsInitializer(dist, clabels)
elif isinstance(distribution, tuple): elif isinstance(distribution, tuple):
num_classes, prototypes_per_class = distribution num_classes, prototypes_per_class = distribution
labels = EqualLabelsInitializer(num_classes, prototypes_per_class) labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
@ -156,7 +157,7 @@ class LabeledComponents(Components):
# Components # Components
if isinstance(initializer, ClassAwareInitializer): if isinstance(initializer, ClassAwareInitializer):
_new = initializer.generate(len(new_labels), labels.distribution) _new = initializer.generate(len(new_labels), distribution)
else: else:
_new = initializer.generate(len(new_labels)) _new = initializer.generate(len(new_labels))
_components = torch.cat([self._components, _new]) _components = torch.cat([self._components, _new])

View File

@ -174,19 +174,17 @@ class LabelsInitializer:
class UnequalLabelsInitializer(LabelsInitializer): class UnequalLabelsInitializer(LabelsInitializer):
def __init__(self, dist): def __init__(self, dist, clabels=None):
self.dist = dist self.dist = dist
self.clabels = clabels or range(len(self.dist))
@property @property
def distribution(self): def distribution(self):
return self.dist return self.dist
def generate(self, clabels=None, dist=None): def generate(self):
if not clabels: targets = list(
clabels = range(len(self.dist)) chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
if not dist:
dist = self.dist
targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
return torch.LongTensor(targets) return torch.LongTensor(targets)
@ -203,13 +201,6 @@ class EqualLabelsInitializer(LabelsInitializer):
return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten() return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()
class CustomLabelsInitializer(UnequalLabelsInitializer):
def generate(self):
clabels = list(self.dist.keys())
dist = list(self.dist.values())
return super().generate(clabels, dist)
# Reasonings # Reasonings
class ReasoningsInitializer: class ReasoningsInitializer:
def generate(self, length): def generate(self, length):

View File

@ -10,6 +10,7 @@ class WTAC(torch.nn.Module):
Thin wrapper over the `wtac` function. Thin wrapper over the `wtac` function.
""" """
def forward(self, distances, labels): def forward(self, distances, labels):
return wtac(distances, labels) return wtac(distances, labels)
@ -20,6 +21,7 @@ class LTAC(torch.nn.Module):
Thin wrapper over the `wtac` function. Thin wrapper over the `wtac` function.
""" """
def forward(self, probs, labels): def forward(self, probs, labels):
return wtac(-1.0 * probs, labels) return wtac(-1.0 * probs, labels)
@ -30,6 +32,7 @@ class KNNC(torch.nn.Module):
Thin wrapper over the `knnc` function. Thin wrapper over the `knnc` function.
""" """
def __init__(self, k=1, **kwargs): def __init__(self, k=1, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.k = k self.k = k