[REFACTOR] Remove CustomLabelsInitializer

commit c0c0044a42
parent 47d7f5831f
@@ -5,7 +5,6 @@ import warnings
 import torch
 from prototorch.components.initializers import (ClassAwareInitializer,
                                                 ComponentsInitializer,
-                                                CustomLabelsInitializer,
                                                 EqualLabelsInitializer,
                                                 UnequalLabelsInitializer,
                                                 ZeroReasoningsInitializer)
@@ -21,7 +20,9 @@ def get_labels_object(distribution):
                 distribution["num_classes"],
                 distribution["prototypes_per_class"])
         else:
-            labels = CustomLabelsInitializer(distribution)
+            clabels = list(distribution.keys())
+            dist = list(distribution.values())
+            labels = UnequalLabelsInitializer(dist, clabels)
     elif isinstance(distribution, tuple):
         num_classes, prototypes_per_class = distribution
         labels = EqualLabelsInitializer(num_classes, prototypes_per_class)
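With CustomLabelsInitializer gone, get_labels_object now unpacks a dict-shaped distribution itself and hands the pieces straight to UnequalLabelsInitializer. A minimal sketch of the new code path (the distribution values are illustrative, not from the commit):

from prototorch.components.initializers import UnequalLabelsInitializer

# Hypothetical dict distribution: class label -> prototypes per class.
distribution = {0: 4, 2: 1, 3: 2}

clabels = list(distribution.keys())   # [0, 2, 3]
dist = list(distribution.values())    # [4, 1, 2]

labels = UnequalLabelsInitializer(dist, clabels)
print(labels.generate())              # tensor([0, 0, 0, 0, 2, 3, 3])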
@@ -156,7 +157,7 @@ class LabeledComponents(Components):

         # Components
         if isinstance(initializer, ClassAwareInitializer):
-            _new = initializer.generate(len(new_labels), labels.distribution)
+            _new = initializer.generate(len(new_labels), distribution)
         else:
             _new = initializer.generate(len(new_labels))
         _components = torch.cat([self._components, _new])
@@ -174,19 +174,17 @@ class LabelsInitializer:


 class UnequalLabelsInitializer(LabelsInitializer):
-    def __init__(self, dist):
+    def __init__(self, dist, clabels=None):
         self.dist = dist
+        self.clabels = clabels or range(len(self.dist))

     @property
     def distribution(self):
         return self.dist

-    def generate(self, clabels=None, dist=None):
-        if not clabels:
-            clabels = range(len(self.dist))
-        if not dist:
-            dist = self.dist
-        targets = list(chain(*[[i] * n for i, n in zip(clabels, dist)]))
+    def generate(self):
+        targets = list(
+            chain(*[[i] * n for i, n in zip(self.clabels, self.dist)]))
         return torch.LongTensor(targets)

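The class labels now live on the instance instead of being threaded through generate(), and since clabels defaults to range(len(dist)), plain lists of counts still work. A standalone sketch of the chain mechanics (values are illustrative):

from itertools import chain

import torch

clabels, dist = [0, 2], [3, 1]                      # illustrative values
nested = [[i] * n for i, n in zip(clabels, dist)]   # [[0, 0, 0], [2]]
targets = list(chain(*nested))                      # [0, 0, 0, 2]
print(torch.LongTensor(targets))                    # tensor([0, 0, 0, 2])

# Default path: omitting clabels falls back to range(len(dist)),
# i.e. consecutive class indices 0, 1, 2, ...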
@@ -203,13 +201,6 @@ class EqualLabelsInitializer(LabelsInitializer):
         return torch.arange(self.classes).repeat(self.per_class, 1).T.flatten()


-class CustomLabelsInitializer(UnequalLabelsInitializer):
-    def generate(self):
-        clabels = list(self.dist.keys())
-        dist = list(self.dist.values())
-        return super().generate(clabels, dist)
-
-
 # Reasonings
 class ReasoningsInitializer:
     def generate(self, length):
@@ -10,6 +10,7 @@ class WTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.
+
     """

     def forward(self, distances, labels):
         return wtac(distances, labels)

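wtac itself is imported from elsewhere in prototorch and is not part of this diff; as a hedged sketch of the winner-takes-all idea it wraps (not the library's exact implementation):

import torch

def wtac_sketch(distances, labels):
    # Winner takes all: each sample gets the label of its closest prototype.
    winning_indices = distances.argmin(dim=1)
    return labels[winning_indices]

distances = torch.tensor([[0.2, 1.5, 0.9],
                          [2.0, 0.1, 0.4]])
labels = torch.LongTensor([0, 1, 1])
print(wtac_sketch(distances, labels))  # tensor([0, 1])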
@@ -20,6 +21,7 @@ class LTAC(torch.nn.Module):
     Thin wrapper over the `wtac` function.
+
     """

     def forward(self, probs, labels):
         return wtac(-1.0 * probs, labels)

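Negating the probabilities turns the distance-minimising wtac into a probability-maximising competition, so LTAC picks the label with the largest probability. Reusing the sketch above (illustrative values):

probs = torch.tensor([[0.7, 0.2, 0.1]])
labels = torch.LongTensor([0, 1, 2])
print(wtac_sketch(-1.0 * probs, labels))  # tensor([0]): highest probability wins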
@@ -30,6 +32,7 @@ class KNNC(torch.nn.Module):
     Thin wrapper over the `knnc` function.
+
     """

     def __init__(self, k=1, **kwargs):
         super().__init__(**kwargs)
         self.k = k
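knnc is likewise imported rather than defined here; a hedged sketch of a k-nearest-neighbours competition consistent with the k stored above (again, not the library's exact code):

import torch

def knnc_sketch(distances, labels, k=1):
    # Majority vote over the labels of the k closest prototypes per sample.
    winning_indices = distances.topk(k, largest=False).indices  # shape (n, k)
    return torch.mode(labels[winning_indices], dim=1).values

distances = torch.tensor([[0.2, 0.3, 0.9, 1.5]])
labels = torch.LongTensor([0, 1, 1, 0])
print(knnc_sketch(distances, labels, k=3))  # tensor([1]): two of three neighbours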