feat: warn user when component counts do not match
This commit is contained in:
parent
08b3f9bbb9
commit
ed5b9b6c62
@ -26,11 +26,18 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
|
|||||||
Use this to 'generate' pre-initialized components elsewhere.
|
Use this to 'generate' pre-initialized components elsewhere.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, components):
|
def __init__(self, components):
|
||||||
self.components = components
|
self.components = components
|
||||||
|
|
||||||
def generate(self, num_components: int = 0):
|
def generate(self, num_components: int = 0):
|
||||||
"""Ignore `num_components` and simply return `self.components`."""
|
"""Ignore `num_components` and simply return `self.components`."""
|
||||||
|
provided_num_components = len(self.components)
|
||||||
|
if provided_num_components != num_components:
|
||||||
|
wmsg = f"The number of components ({provided_num_components}) " \
|
||||||
|
f"provided to {self.__class__.__name__} " \
|
||||||
|
f"does not match the expected number ({num_components})."
|
||||||
|
warnings.warn(wmsg)
|
||||||
if not isinstance(self.components, torch.Tensor):
|
if not isinstance(self.components, torch.Tensor):
|
||||||
wmsg = f"Converting components to {torch.Tensor}..."
|
wmsg = f"Converting components to {torch.Tensor}..."
|
||||||
warnings.warn(wmsg)
|
warnings.warn(wmsg)
|
||||||
@ -40,6 +47,7 @@ class LiteralCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
||||||
"""Abstract class for all dimension-aware components initializers."""
|
"""Abstract class for all dimension-aware components initializers."""
|
||||||
|
|
||||||
def __init__(self, shape: Union[Iterable, int]):
|
def __init__(self, shape: Union[Iterable, int]):
|
||||||
if isinstance(shape, Iterable):
|
if isinstance(shape, Iterable):
|
||||||
self.component_shape = tuple(shape)
|
self.component_shape = tuple(shape)
|
||||||
@ -53,6 +61,7 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
class ZerosCompInitializer(ShapeAwareCompInitializer):
|
||||||
"""Generate zeros corresponding to the components shape."""
|
"""Generate zeros corresponding to the components shape."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
components = torch.zeros((num_components, ) + self.component_shape)
|
components = torch.zeros((num_components, ) + self.component_shape)
|
||||||
return components
|
return components
|
||||||
@ -60,6 +69,7 @@ class ZerosCompInitializer(ShapeAwareCompInitializer):
|
|||||||
|
|
||||||
class OnesCompInitializer(ShapeAwareCompInitializer):
|
class OnesCompInitializer(ShapeAwareCompInitializer):
|
||||||
"""Generate ones corresponding to the components shape."""
|
"""Generate ones corresponding to the components shape."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
components = torch.ones((num_components, ) + self.component_shape)
|
components = torch.ones((num_components, ) + self.component_shape)
|
||||||
return components
|
return components
|
||||||
@ -67,6 +77,7 @@ class OnesCompInitializer(ShapeAwareCompInitializer):
|
|||||||
|
|
||||||
class FillValueCompInitializer(OnesCompInitializer):
|
class FillValueCompInitializer(OnesCompInitializer):
|
||||||
"""Generate components with the provided `fill_value`."""
|
"""Generate components with the provided `fill_value`."""
|
||||||
|
|
||||||
def __init__(self, shape, fill_value: float = 1.0):
|
def __init__(self, shape, fill_value: float = 1.0):
|
||||||
super().__init__(shape)
|
super().__init__(shape)
|
||||||
self.fill_value = fill_value
|
self.fill_value = fill_value
|
||||||
@ -79,6 +90,7 @@ class FillValueCompInitializer(OnesCompInitializer):
|
|||||||
|
|
||||||
class UniformCompInitializer(OnesCompInitializer):
|
class UniformCompInitializer(OnesCompInitializer):
|
||||||
"""Generate components by sampling from a continuous uniform distribution."""
|
"""Generate components by sampling from a continuous uniform distribution."""
|
||||||
|
|
||||||
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
|
def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0):
|
||||||
super().__init__(shape)
|
super().__init__(shape)
|
||||||
self.minimum = minimum
|
self.minimum = minimum
|
||||||
@ -93,6 +105,7 @@ class UniformCompInitializer(OnesCompInitializer):
|
|||||||
|
|
||||||
class RandomNormalCompInitializer(OnesCompInitializer):
|
class RandomNormalCompInitializer(OnesCompInitializer):
|
||||||
"""Generate components by sampling from a standard normal distribution."""
|
"""Generate components by sampling from a standard normal distribution."""
|
||||||
|
|
||||||
def __init__(self, shape, shift=0.0, scale=1.0):
|
def __init__(self, shape, shift=0.0, scale=1.0):
|
||||||
super().__init__(shape)
|
super().__init__(shape)
|
||||||
self.shift = shift
|
self.shift = shift
|
||||||
@ -113,6 +126,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
`data` has to be a torch tensor.
|
`data` has to be a torch tensor.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data: torch.Tensor,
|
data: torch.Tensor,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
@ -137,6 +151,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""'Generate' the components from the provided data."""
|
"""'Generate' the components from the provided data."""
|
||||||
|
|
||||||
def generate(self, num_components: int = 0):
|
def generate(self, num_components: int = 0):
|
||||||
"""Ignore `num_components` and simply return transformed `self.data`."""
|
"""Ignore `num_components` and simply return transformed `self.data`."""
|
||||||
components = self.generate_end_hook(self.data)
|
components = self.generate_end_hook(self.data)
|
||||||
@ -145,6 +160,7 @@ class DataAwareCompInitializer(AbstractDataAwareCompInitializer):
|
|||||||
|
|
||||||
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""Generate components by uniformly sampling from the provided data."""
|
"""Generate components by uniformly sampling from the provided data."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
indices = torch.LongTensor(num_components).random_(0, len(self.data))
|
||||||
samples = self.data[indices]
|
samples = self.data[indices]
|
||||||
@ -154,6 +170,7 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer):
|
|||||||
|
|
||||||
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
class MeanCompInitializer(AbstractDataAwareCompInitializer):
|
||||||
"""Generate components by computing the mean of the provided data."""
|
"""Generate components by computing the mean of the provided data."""
|
||||||
|
|
||||||
def generate(self, num_components: int):
|
def generate(self, num_components: int):
|
||||||
mean = self.data.mean(dim=0)
|
mean = self.data.mean(dim=0)
|
||||||
repeat_dim = [num_components] + [1] * len(mean.shape)
|
repeat_dim = [num_components] + [1] * len(mean.shape)
|
||||||
@ -172,6 +189,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
target tensors.
|
target tensors.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data,
|
data,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
@ -199,6 +217,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer):
|
|||||||
|
|
||||||
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
||||||
"""'Generate' components from provided data and requested distribution."""
|
"""'Generate' components from provided data and requested distribution."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
"""Ignore `distribution` and simply return transformed `self.data`."""
|
"""Ignore `distribution` and simply return transformed `self.data`."""
|
||||||
components = self.generate_end_hook(self.data)
|
components = self.generate_end_hook(self.data)
|
||||||
@ -207,6 +226,7 @@ class ClassAwareCompInitializer(AbstractClassAwareCompInitializer):
|
|||||||
|
|
||||||
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
||||||
"""Abstract class for all stratified components initializers."""
|
"""Abstract class for all stratified components initializers."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
|
def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]:
|
||||||
@ -231,6 +251,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer):
|
|||||||
|
|
||||||
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
||||||
"""Generate components using stratified sampling from the provided data."""
|
"""Generate components using stratified sampling from the provided data."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subinit_type(self):
|
def subinit_type(self):
|
||||||
return SelectionCompInitializer
|
return SelectionCompInitializer
|
||||||
@ -238,6 +259,7 @@ class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer):
|
|||||||
|
|
||||||
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
||||||
"""Generate components at stratified means of the provided data."""
|
"""Generate components at stratified means of the provided data."""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def subinit_type(self):
|
def subinit_type(self):
|
||||||
return MeanCompInitializer
|
return MeanCompInitializer
|
||||||
@ -246,6 +268,7 @@ class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer):
|
|||||||
# Labels
|
# Labels
|
||||||
class AbstractLabelsInitializer(ABC):
|
class AbstractLabelsInitializer(ABC):
|
||||||
"""Abstract class for all labels initializers."""
|
"""Abstract class for all labels initializers."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
...
|
...
|
||||||
@ -257,6 +280,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
|||||||
Use this to 'generate' pre-initialized labels elsewhere.
|
Use this to 'generate' pre-initialized labels elsewhere.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, labels):
|
def __init__(self, labels):
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
|
||||||
@ -275,6 +299,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer):
|
|||||||
|
|
||||||
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
||||||
"""'Generate' the labels from a torch Dataset."""
|
"""'Generate' the labels from a torch Dataset."""
|
||||||
|
|
||||||
def __init__(self, data):
|
def __init__(self, data):
|
||||||
self.data, self.targets = parse_data_arg(data)
|
self.data, self.targets = parse_data_arg(data)
|
||||||
|
|
||||||
@ -285,6 +310,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer):
|
|||||||
|
|
||||||
class LabelsInitializer(AbstractLabelsInitializer):
|
class LabelsInitializer(AbstractLabelsInitializer):
|
||||||
"""Generate labels from `distribution`."""
|
"""Generate labels from `distribution`."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
labels_list = []
|
labels_list = []
|
||||||
@ -296,6 +322,7 @@ class LabelsInitializer(AbstractLabelsInitializer):
|
|||||||
|
|
||||||
class OneHotLabelsInitializer(LabelsInitializer):
|
class OneHotLabelsInitializer(LabelsInitializer):
|
||||||
"""Generate one-hot-encoded labels from `distribution`."""
|
"""Generate one-hot-encoded labels from `distribution`."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
distribution = parse_distribution(distribution)
|
distribution = parse_distribution(distribution)
|
||||||
num_classes = len(distribution.keys())
|
num_classes = len(distribution.keys())
|
||||||
@ -314,6 +341,7 @@ def compute_distribution_shape(distribution):
|
|||||||
|
|
||||||
class AbstractReasoningsInitializer(ABC):
|
class AbstractReasoningsInitializer(ABC):
|
||||||
"""Abstract class for all reasonings initializers."""
|
"""Abstract class for all reasonings initializers."""
|
||||||
|
|
||||||
def __init__(self, components_first: bool = True):
|
def __init__(self, components_first: bool = True):
|
||||||
self.components_first = components_first
|
self.components_first = components_first
|
||||||
|
|
||||||
@ -334,6 +362,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
Use this to 'generate' pre-initialized reasonings elsewhere.
|
Use this to 'generate' pre-initialized reasonings elsewhere.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reasonings, **kwargs):
|
def __init__(self, reasonings, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.reasonings = reasonings
|
self.reasonings = reasonings
|
||||||
@ -351,6 +380,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with zeros."""
|
"""Reasonings are all initialized with zeros."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = compute_distribution_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.zeros(*shape)
|
reasonings = torch.zeros(*shape)
|
||||||
@ -360,6 +390,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are all initialized with ones."""
|
"""Reasonings are all initialized with ones."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
shape = compute_distribution_shape(distribution)
|
shape = compute_distribution_shape(distribution)
|
||||||
reasonings = torch.ones(*shape)
|
reasonings = torch.ones(*shape)
|
||||||
@ -369,6 +400,7 @@ class OnesReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Reasonings are randomly initialized."""
|
"""Reasonings are randomly initialized."""
|
||||||
|
|
||||||
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
def __init__(self, minimum=0.4, maximum=0.6, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.minimum = minimum
|
self.minimum = minimum
|
||||||
@ -383,6 +415,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
|
|
||||||
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
"""Each component reasons positively for exactly one class."""
|
"""Each component reasons positively for exactly one class."""
|
||||||
|
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
num_components, num_classes, _ = compute_distribution_shape(
|
num_components, num_classes, _ = compute_distribution_shape(
|
||||||
distribution)
|
distribution)
|
||||||
@ -401,6 +434,7 @@ class AbstractTransformInitializer(ABC):
|
|||||||
|
|
||||||
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||||
"""Abstract class for all linear transform initializers."""
|
"""Abstract class for all linear transform initializers."""
|
||||||
|
|
||||||
def __init__(self, out_dim_first: bool = False):
|
def __init__(self, out_dim_first: bool = False):
|
||||||
self.out_dim_first = out_dim_first
|
self.out_dim_first = out_dim_first
|
||||||
|
|
||||||
@ -417,6 +451,7 @@ class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
|||||||
|
|
||||||
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Initialize a matrix with zeros."""
|
"""Initialize a matrix with zeros."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
weights = torch.zeros(in_dim, out_dim)
|
weights = torch.zeros(in_dim, out_dim)
|
||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
@ -424,6 +459,7 @@ class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Initialize a matrix with ones."""
|
"""Initialize a matrix with ones."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
weights = torch.ones(in_dim, out_dim)
|
weights = torch.ones(in_dim, out_dim)
|
||||||
return self.generate_end_hook(weights)
|
return self.generate_end_hook(weights)
|
||||||
@ -431,6 +467,7 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Initialize a matrix with the largest possible identity matrix."""
|
"""Initialize a matrix with the largest possible identity matrix."""
|
||||||
|
|
||||||
def generate(self, in_dim: int, out_dim: int):
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
weights = torch.zeros(in_dim, out_dim)
|
weights = torch.zeros(in_dim, out_dim)
|
||||||
I = torch.eye(min(in_dim, out_dim))
|
I = torch.eye(min(in_dim, out_dim))
|
||||||
@ -440,6 +477,7 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
|||||||
|
|
||||||
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
|
||||||
"""Abstract class for all data-aware linear transform initializers."""
|
"""Abstract class for all data-aware linear transform initializers."""
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
data: torch.Tensor,
|
data: torch.Tensor,
|
||||||
noise: float = 0.0,
|
noise: float = 0.0,
|
||||||
|
Loading…
Reference in New Issue
Block a user