[REFACTOR] Simplify ReasoningComponents
This commit is contained in:
parent
6ad665f8c2
commit
d2d6f31e7b
@ -17,6 +17,7 @@ from .initializers import (
|
|||||||
|
|
||||||
|
|
||||||
def validate_initializer(initializer, instanceof):
|
def validate_initializer(initializer, instanceof):
|
||||||
|
"""Check if the initializer is valid."""
|
||||||
if not isinstance(initializer, instanceof):
|
if not isinstance(initializer, instanceof):
|
||||||
emsg = f"`initializer` has to be an instance " \
|
emsg = f"`initializer` has to be an instance " \
|
||||||
f"of some subtype of {instanceof}. " \
|
f"of some subtype of {instanceof}. " \
|
||||||
@ -47,6 +48,17 @@ def removeind(ins, attr, indices):
|
|||||||
return items, mask
|
return items, mask
|
||||||
|
|
||||||
|
|
||||||
|
def get_cikwargs(init, distribution):
|
||||||
|
"""Return appropriate key-word arguments for a component initializer."""
|
||||||
|
if isinstance(init, ClassAwareCompInitializer):
|
||||||
|
cikwargs = dict(distribution=distribution)
|
||||||
|
else:
|
||||||
|
distribution = parse_distribution(distribution)
|
||||||
|
num_components = sum(distribution.values())
|
||||||
|
cikwargs = dict(num_components=num_components)
|
||||||
|
return cikwargs
|
||||||
|
|
||||||
|
|
||||||
class AbstractComponents(torch.nn.Module):
|
class AbstractComponents(torch.nn.Module):
|
||||||
"""Abstract class for all components modules."""
|
"""Abstract class for all components modules."""
|
||||||
@property
|
@property
|
||||||
@ -194,12 +206,7 @@ class LabeledComponents(AbstractComponents):
|
|||||||
AbstractComponentsInitializer)
|
AbstractComponentsInitializer)
|
||||||
assert validate_initializer(labels_initializer,
|
assert validate_initializer(labels_initializer,
|
||||||
AbstractLabelsInitializer)
|
AbstractLabelsInitializer)
|
||||||
if isinstance(components_initializer, ClassAwareCompInitializer):
|
cikwargs = get_cikwargs(components_initializer, distribution)
|
||||||
cikwargs = dict(distribution=distribution)
|
|
||||||
else:
|
|
||||||
distribution = parse_distribution(distribution)
|
|
||||||
num_components = sum(distribution.values())
|
|
||||||
cikwargs = dict(num_components=num_components)
|
|
||||||
_components, new_components = gencat(self, "_components",
|
_components, new_components = gencat(self, "_components",
|
||||||
components_initializer,
|
components_initializer,
|
||||||
**cikwargs)
|
**cikwargs)
|
||||||
@ -259,47 +266,28 @@ class ReasoningComponents(AbstractComponents):
|
|||||||
|
|
||||||
def add_components(self, distribution, components_initializer,
|
def add_components(self, distribution, components_initializer,
|
||||||
reasonings_initializer: AbstractReasoningsInitializer):
|
reasonings_initializer: AbstractReasoningsInitializer):
|
||||||
# Checks
|
"""Generate and add new components and reasonings."""
|
||||||
assert validate_initializer(components_initializer,
|
assert validate_initializer(components_initializer,
|
||||||
AbstractComponentsInitializer)
|
AbstractComponentsInitializer)
|
||||||
assert validate_initializer(reasonings_initializer,
|
assert validate_initializer(reasonings_initializer,
|
||||||
AbstractReasoningsInitializer)
|
AbstractReasoningsInitializer)
|
||||||
|
cikwargs = get_cikwargs(components_initializer, distribution)
|
||||||
distribution = parse_distribution(distribution)
|
_components, new_components = gencat(self, "_components",
|
||||||
|
components_initializer,
|
||||||
# Generate new components
|
**cikwargs)
|
||||||
if isinstance(components_initializer, ClassAwareCompInitializer):
|
_reasonings, new_reasonings = gencat(self, "_reasonings",
|
||||||
new_components = components_initializer.generate(distribution)
|
reasonings_initializer,
|
||||||
else:
|
distribution)
|
||||||
num_components = sum(distribution.values())
|
|
||||||
new_components = components_initializer.generate(num_components)
|
|
||||||
|
|
||||||
# Generate new reasonings
|
|
||||||
new_reasonings = reasonings_initializer.generate(distribution)
|
|
||||||
|
|
||||||
# Register
|
|
||||||
if hasattr(self, "_components"):
|
|
||||||
_components = torch.cat([self._components, new_components])
|
|
||||||
else:
|
|
||||||
_components = new_components
|
|
||||||
if hasattr(self, "_reasonings"):
|
|
||||||
_reasonings = torch.cat([self._reasonings, new_reasonings])
|
|
||||||
else:
|
|
||||||
_reasonings = new_reasonings
|
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
self._register_reasonings(_reasonings)
|
self._register_reasonings(_reasonings)
|
||||||
|
|
||||||
return new_components, new_reasonings
|
return new_components, new_reasonings
|
||||||
|
|
||||||
def remove_components(self, indices):
|
def remove_components(self, indices):
|
||||||
"""Remove components and labels at specified indices."""
|
"""Remove components and reasonings at specified indices."""
|
||||||
mask = torch.ones(self.num_components, dtype=torch.bool)
|
_components, mask = removeind(self, "_components", indices)
|
||||||
mask[indices] = False
|
_reasonings, mask = removeind(self, "_reasonings", indices)
|
||||||
_components = self._components[mask]
|
|
||||||
# TODO
|
|
||||||
# _reasonings = self._reasonings[mask]
|
|
||||||
self._register_components(_components)
|
self._register_components(_components)
|
||||||
# self._register_reasonings(_reasonings)
|
self._register_reasonings(_reasonings)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
def forward(self):
|
def forward(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user