[QA] Fix for "no-self-use" (R0201)
This commit is contained in:
parent
e8e803e8ef
commit
b1568a550a
@ -32,6 +32,14 @@ def get_labels_object(distribution):
|
|||||||
return labels
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def _precheck_initializer(self, initializer):
|
||||||
|
if not isinstance(initializer, ComponentsInitializer):
|
||||||
|
emsg = f"`initializer` has to be some subtype of " \
|
||||||
|
f"{ComponentsInitializer}. " \
|
||||||
|
f"You have provided: {initializer=} instead."
|
||||||
|
raise TypeError(emsg)
|
||||||
|
|
||||||
|
|
||||||
class Components(torch.nn.Module):
|
class Components(torch.nn.Module):
|
||||||
"""Components is a set of learnable Tensors."""
|
"""Components is a set of learnable Tensors."""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -58,15 +66,8 @@ class Components(torch.nn.Module):
|
|||||||
def _register_components(self, components):
|
def _register_components(self, components):
|
||||||
self.register_parameter("_components", Parameter(components))
|
self.register_parameter("_components", Parameter(components))
|
||||||
|
|
||||||
def _precheck_initializer(self, initializer):
|
|
||||||
if not isinstance(initializer, ComponentsInitializer):
|
|
||||||
emsg = f"`initializer` has to be some subtype of " \
|
|
||||||
f"{ComponentsInitializer}. " \
|
|
||||||
f"You have provided: {initializer=} instead."
|
|
||||||
raise TypeError(emsg)
|
|
||||||
|
|
||||||
def _initialize_components(self, initializer, num_components):
|
def _initialize_components(self, initializer, num_components):
|
||||||
self._precheck_initializer(initializer)
|
_precheck_initializer(initializer)
|
||||||
_components = initializer.generate(num_components)
|
_components = initializer.generate(num_components)
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
|
|
||||||
@ -78,7 +79,7 @@ class Components(torch.nn.Module):
|
|||||||
if initialized_components is not None:
|
if initialized_components is not None:
|
||||||
_components = torch.cat([self._components, initialized_components])
|
_components = torch.cat([self._components, initialized_components])
|
||||||
else:
|
else:
|
||||||
self._precheck_initializer(initializer)
|
_precheck_initializer(initializer)
|
||||||
_new = initializer.generate(num)
|
_new = initializer.generate(num)
|
||||||
_components = torch.cat([self._components, _new])
|
_components = torch.cat([self._components, _new])
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
@ -128,12 +129,14 @@ class LabeledComponents(Components):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def distribution(self):
|
def distribution(self):
|
||||||
clabels, counts = torch.unique(self._labels, sorted=True, return_counts=True)
|
clabels, counts = torch.unique(self._labels,
|
||||||
|
sorted=True,
|
||||||
|
return_counts=True)
|
||||||
return dict(zip(clabels.tolist(), counts.tolist()))
|
return dict(zip(clabels.tolist(), counts.tolist()))
|
||||||
|
|
||||||
def _initialize_components(self, initializer, num_components):
|
def _initialize_components(self, initializer, num_components):
|
||||||
if isinstance(initializer, ClassAwareInitializer):
|
if isinstance(initializer, ClassAwareInitializer):
|
||||||
self._precheck_initializer(initializer)
|
_precheck_initializer(initializer)
|
||||||
_components = initializer.generate(num_components,
|
_components = initializer.generate(num_components,
|
||||||
self.initial_distribution)
|
self.initial_distribution)
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
@ -141,7 +144,7 @@ class LabeledComponents(Components):
|
|||||||
super()._initialize_components(initializer, num_components)
|
super()._initialize_components(initializer, num_components)
|
||||||
|
|
||||||
def add_components(self, initializer, distribution):
|
def add_components(self, initializer, distribution):
|
||||||
self._precheck_initializer(initializer)
|
_precheck_initializer(initializer)
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
labels = get_labels_object(distribution)
|
labels = get_labels_object(distribution)
|
||||||
|
Loading…
Reference in New Issue
Block a user