[REFACTOR] Simplify ReasoningComponents
This commit is contained in:
parent
6ad665f8c2
commit
d2d6f31e7b
@ -17,6 +17,7 @@ from .initializers import (
|
||||
|
||||
|
||||
def validate_initializer(initializer, instanceof):
|
||||
"""Check if the initializer is valid."""
|
||||
if not isinstance(initializer, instanceof):
|
||||
emsg = f"`initializer` has to be an instance " \
|
||||
f"of some subtype of {instanceof}. " \
|
||||
@ -47,6 +48,17 @@ def removeind(ins, attr, indices):
|
||||
return items, mask
|
||||
|
||||
|
||||
def get_cikwargs(init, distribution):
|
||||
"""Return appropriate key-word arguments for a component initializer."""
|
||||
if isinstance(init, ClassAwareCompInitializer):
|
||||
cikwargs = dict(distribution=distribution)
|
||||
else:
|
||||
distribution = parse_distribution(distribution)
|
||||
num_components = sum(distribution.values())
|
||||
cikwargs = dict(num_components=num_components)
|
||||
return cikwargs
|
||||
|
||||
|
||||
class AbstractComponents(torch.nn.Module):
|
||||
"""Abstract class for all components modules."""
|
||||
@property
|
||||
@ -194,12 +206,7 @@ class LabeledComponents(AbstractComponents):
|
||||
AbstractComponentsInitializer)
|
||||
assert validate_initializer(labels_initializer,
|
||||
AbstractLabelsInitializer)
|
||||
if isinstance(components_initializer, ClassAwareCompInitializer):
|
||||
cikwargs = dict(distribution=distribution)
|
||||
else:
|
||||
distribution = parse_distribution(distribution)
|
||||
num_components = sum(distribution.values())
|
||||
cikwargs = dict(num_components=num_components)
|
||||
cikwargs = get_cikwargs(components_initializer, distribution)
|
||||
_components, new_components = gencat(self, "_components",
|
||||
components_initializer,
|
||||
**cikwargs)
|
||||
@ -259,47 +266,28 @@ class ReasoningComponents(AbstractComponents):
|
||||
|
||||
def add_components(self, distribution, components_initializer,
|
||||
reasonings_initializer: AbstractReasoningsInitializer):
|
||||
# Checks
|
||||
"""Generate and add new components and reasonings."""
|
||||
assert validate_initializer(components_initializer,
|
||||
AbstractComponentsInitializer)
|
||||
assert validate_initializer(reasonings_initializer,
|
||||
AbstractReasoningsInitializer)
|
||||
|
||||
distribution = parse_distribution(distribution)
|
||||
|
||||
# Generate new components
|
||||
if isinstance(components_initializer, ClassAwareCompInitializer):
|
||||
new_components = components_initializer.generate(distribution)
|
||||
else:
|
||||
num_components = sum(distribution.values())
|
||||
new_components = components_initializer.generate(num_components)
|
||||
|
||||
# Generate new reasonings
|
||||
new_reasonings = reasonings_initializer.generate(distribution)
|
||||
|
||||
# Register
|
||||
if hasattr(self, "_components"):
|
||||
_components = torch.cat([self._components, new_components])
|
||||
else:
|
||||
_components = new_components
|
||||
if hasattr(self, "_reasonings"):
|
||||
_reasonings = torch.cat([self._reasonings, new_reasonings])
|
||||
else:
|
||||
_reasonings = new_reasonings
|
||||
cikwargs = get_cikwargs(components_initializer, distribution)
|
||||
_components, new_components = gencat(self, "_components",
|
||||
components_initializer,
|
||||
**cikwargs)
|
||||
_reasonings, new_reasonings = gencat(self, "_reasonings",
|
||||
reasonings_initializer,
|
||||
distribution)
|
||||
self._register_components(_components)
|
||||
self._register_reasonings(_reasonings)
|
||||
|
||||
return new_components, new_reasonings
|
||||
|
||||
def remove_components(self, indices):
|
||||
"""Remove components and labels at specified indices."""
|
||||
mask = torch.ones(self.num_components, dtype=torch.bool)
|
||||
mask[indices] = False
|
||||
_components = self._components[mask]
|
||||
# TODO
|
||||
# _reasonings = self._reasonings[mask]
|
||||
"""Remove components and reasonings at specified indices."""
|
||||
_components, mask = removeind(self, "_components", indices)
|
||||
_reasonings, mask = removeind(self, "_reasonings", indices)
|
||||
self._register_components(_components)
|
||||
# self._register_reasonings(_reasonings)
|
||||
self._register_reasonings(_reasonings)
|
||||
return mask
|
||||
|
||||
def forward(self):
|
||||
|
Loading…
Reference in New Issue
Block a user