From 8200e1d3d8a62a2f6b28bddd715d504469992da4 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 4 Jun 2021 22:13:36 +0200 Subject: [PATCH] [FEATURE] Allow `initialized_components` to be a dataset --- prototorch/components/components.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/prototorch/components/components.py b/prototorch/components/components.py index aa7bb70..5968da5 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -11,6 +11,8 @@ from prototorch.components.initializers import (ClassAwareInitializer, ZeroReasoningsInitializer) from torch.nn.parameter import Parameter +from .initializers import parse_data_arg + def get_labels_object(distribution): if isinstance(distribution, dict): @@ -56,24 +58,23 @@ class Components(torch.nn.Module): wmsg = "Arguments ignored while initializing Components" warnings.warn(wmsg) else: - self._initialize_components(initializer, num_components) + self._initialize_components(num_components, initializer) @property def num_components(self): - # return len(self._components) - return self._components.shape[0] + return len(self._components) def _register_components(self, components): self.register_parameter("_components", Parameter(components)) - def _initialize_components(self, initializer, num_components): + def _initialize_components(self, num_components, initializer): _precheck_initializer(initializer) _components = initializer.generate(num_components) self._register_components(_components) def add_components(self, - initializer=None, num=1, + initializer=None, *, initialized_components=None): if initialized_components is not None: @@ -114,7 +115,8 @@ class LabeledComponents(Components): *, initialized_components=None): if initialized_components is not None: - components, component_labels = initialized_components + components, component_labels = parse_data_arg( + initialized_components) super().__init__(initialized_components=components) self._labels = component_labels else: @@ -134,16 +136,16 @@ class LabeledComponents(Components): return_counts=True) return dict(zip(clabels.tolist(), counts.tolist())) - def _initialize_components(self, initializer, num_components): + def _initialize_components(self, num_components, initializer): if isinstance(initializer, ClassAwareInitializer): _precheck_initializer(initializer) _components = initializer.generate(num_components, self.initial_distribution) self._register_components(_components) else: - super()._initialize_components(initializer, num_components) + super()._initialize_components(num_components, initializer) - def add_components(self, initializer, distribution): + def add_components(self, distribution, initializer): _precheck_initializer(initializer) # Labels