diff --git a/prototorch/components/components.py b/prototorch/components/components.py index f485f5d..5e7b9d6 100644 --- a/prototorch/components/components.py +++ b/prototorch/components/components.py @@ -41,8 +41,6 @@ class Components(torch.nn.Module): initialized_components=None): super().__init__() - self.num_components = num_components - # Ignore all initialization settings if initialized_components is given. if initialized_components is not None: self._register_components(initialized_components) @@ -50,7 +48,12 @@ class Components(torch.nn.Module): wmsg = "Arguments ignored while initializing Components" warnings.warn(wmsg) else: - self._initialize_components(initializer) + self._initialize_components(initializer, num_components) + + @property + def num_components(self): + # return len(self._components) + return self._components.shape[0] def _register_components(self, components): self.register_parameter("_components", Parameter(components)) @@ -62,17 +65,31 @@ class Components(torch.nn.Module): f"You have provided: {initializer=} instead." raise TypeError(emsg) - def _initialize_components(self, initializer): + def _initialize_components(self, initializer, num_components): self._precheck_initializer(initializer) - _components = initializer.generate(self.num_components) + _components = initializer.generate(num_components) self._register_components(_components) - def increase_components(self, initializer, num=1): - self._precheck_initializer(initializer) - _new = initializer.generate(num) - _components = torch.cat([self._components, _new]) + def add_components(self, + initializer=None, + num=1, + *, + initialized_components=None): + if initialized_components is not None: + _components = torch.cat([self._components, initialized_components]) + else: + self._precheck_initializer(initializer) + _new = initializer.generate(num) + _components = torch.cat([self._components, _new]) self._register_components(_components) + def remove_components(self, indices=None): + mask = torch.ones(self.num_components, dtype=torch.bool) + mask[indices] = False + _components = self._components[mask] + self._register_components(_components) + return mask + @property def components(self): """Tensor containing the component tensors.""" @@ -101,7 +118,7 @@ class LabeledComponents(Components): self._labels = component_labels else: labels = get_labels_object(distribution) - self.distribution = labels.distribution + self.initial_distribution = labels.distribution _labels = labels.generate() super().__init__(len(_labels), initializer=initializer) self._register_labels(_labels) @@ -109,21 +126,21 @@ class LabeledComponents(Components): def _register_labels(self, labels): self.register_buffer("_labels", labels) - def _update_distribution(self, distribution): - self.distribution = [ - old + new for old, new in zip(self.distribution, distribution) - ] + @property + def distribution(self): + clabels, counts = torch.unique(self._labels, sorted=True, return_counts=True) + return dict(zip(clabels.tolist(), counts.tolist())) - def _initialize_components(self, initializer): + def _initialize_components(self, initializer, num_components): if isinstance(initializer, ClassAwareInitializer): self._precheck_initializer(initializer) - _components = initializer.generate(self.num_components, - self.distribution) + _components = initializer.generate(num_components, + self.initial_distribution) self._register_components(_components) else: - super()._initialize_components(initializer) + super()._initialize_components(initializer, num_components) - def increase_components(self, initializer, distribution=[1]): + def add_components(self, initializer, distribution=[1]): self._precheck_initializer(initializer) # Labels @@ -140,8 +157,13 @@ class LabeledComponents(Components): _components = torch.cat([self._components, _new]) self._register_components(_components) - # Housekeeping - self._update_distribution(labels.distribution) + def remove_components(self, indices=None): + # Components + mask = super().remove_components(indices) + + # Labels + _labels = self._labels[mask] + self._register_labels(_labels) @property def component_labels(self):