diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index 2f9137d..79101f4 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -26,11 +26,18 @@ class LiteralCompInitializer(AbstractComponentsInitializer): Use this to 'generate' pre-initialized components elsewhere. """ + def __init__(self, components): self.components = components def generate(self, num_components: int = 0): """Ignore `num_components` and simply return `self.components`.""" + provided_num_components = len(self.components) + if provided_num_components != num_components: + wmsg = f"The number of components ({provided_num_components}) " \ + f"provided to {self.__class__.__name__} " \ + f"does not match the expected number ({num_components})." + warnings.warn(wmsg) if not isinstance(self.components, torch.Tensor): wmsg = f"Converting components to {torch.Tensor}..." warnings.warn(wmsg) @@ -40,6 +47,7 @@ class LiteralCompInitializer(AbstractComponentsInitializer): class ShapeAwareCompInitializer(AbstractComponentsInitializer): """Abstract class for all dimension-aware components initializers.""" + def __init__(self, shape: Union[Iterable, int]): if isinstance(shape, Iterable): self.component_shape = tuple(shape) @@ -53,6 +61,7 @@ class ShapeAwareCompInitializer(AbstractComponentsInitializer): class ZerosCompInitializer(ShapeAwareCompInitializer): """Generate zeros corresponding to the components shape.""" + def generate(self, num_components: int): components = torch.zeros((num_components, ) + self.component_shape) return components @@ -60,6 +69,7 @@ class ZerosCompInitializer(ShapeAwareCompInitializer): class OnesCompInitializer(ShapeAwareCompInitializer): """Generate ones corresponding to the components shape.""" + def generate(self, num_components: int): components = torch.ones((num_components, ) + self.component_shape) return components @@ -67,6 +77,7 @@ class OnesCompInitializer(ShapeAwareCompInitializer): class FillValueCompInitializer(OnesCompInitializer): """Generate components with the provided `fill_value`.""" + def __init__(self, shape, fill_value: float = 1.0): super().__init__(shape) self.fill_value = fill_value @@ -79,6 +90,7 @@ class FillValueCompInitializer(OnesCompInitializer): class UniformCompInitializer(OnesCompInitializer): """Generate components by sampling from a continuous uniform distribution.""" + def __init__(self, shape, minimum=0.0, maximum=1.0, scale=1.0): super().__init__(shape) self.minimum = minimum @@ -93,6 +105,7 @@ class UniformCompInitializer(OnesCompInitializer): class RandomNormalCompInitializer(OnesCompInitializer): """Generate components by sampling from a standard normal distribution.""" + def __init__(self, shape, shift=0.0, scale=1.0): super().__init__(shape) self.shift = shift @@ -113,6 +126,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer): `data` has to be a torch tensor. """ + def __init__(self, data: torch.Tensor, noise: float = 0.0, @@ -137,6 +151,7 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer): class DataAwareCompInitializer(AbstractDataAwareCompInitializer): """'Generate' the components from the provided data.""" + def generate(self, num_components: int = 0): """Ignore `num_components` and simply return transformed `self.data`.""" components = self.generate_end_hook(self.data) @@ -145,6 +160,7 @@ class DataAwareCompInitializer(AbstractDataAwareCompInitializer): class SelectionCompInitializer(AbstractDataAwareCompInitializer): """Generate components by uniformly sampling from the provided data.""" + def generate(self, num_components: int): indices = torch.LongTensor(num_components).random_(0, len(self.data)) samples = self.data[indices] @@ -154,6 +170,7 @@ class SelectionCompInitializer(AbstractDataAwareCompInitializer): class MeanCompInitializer(AbstractDataAwareCompInitializer): """Generate components by computing the mean of the provided data.""" + def generate(self, num_components: int): mean = self.data.mean(dim=0) repeat_dim = [num_components] + [1] * len(mean.shape) @@ -172,6 +189,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer): target tensors. """ + def __init__(self, data, noise: float = 0.0, @@ -199,6 +217,7 @@ class AbstractClassAwareCompInitializer(AbstractComponentsInitializer): class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): """'Generate' components from provided data and requested distribution.""" + def generate(self, distribution: Union[dict, list, tuple]): """Ignore `distribution` and simply return transformed `self.data`.""" components = self.generate_end_hook(self.data) @@ -207,6 +226,7 @@ class ClassAwareCompInitializer(AbstractClassAwareCompInitializer): class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): """Abstract class for all stratified components initializers.""" + @property @abstractmethod def subinit_type(self) -> Type[AbstractDataAwareCompInitializer]: @@ -231,6 +251,7 @@ class AbstractStratifiedCompInitializer(AbstractClassAwareCompInitializer): class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer): """Generate components using stratified sampling from the provided data.""" + @property def subinit_type(self): return SelectionCompInitializer @@ -238,6 +259,7 @@ class StratifiedSelectionCompInitializer(AbstractStratifiedCompInitializer): class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer): """Generate components at stratified means of the provided data.""" + @property def subinit_type(self): return MeanCompInitializer @@ -246,6 +268,7 @@ class StratifiedMeanCompInitializer(AbstractStratifiedCompInitializer): # Labels class AbstractLabelsInitializer(ABC): """Abstract class for all labels initializers.""" + @abstractmethod def generate(self, distribution: Union[dict, list, tuple]): ... @@ -257,6 +280,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer): Use this to 'generate' pre-initialized labels elsewhere. """ + def __init__(self, labels): self.labels = labels @@ -275,6 +299,7 @@ class LiteralLabelsInitializer(AbstractLabelsInitializer): class DataAwareLabelsInitializer(AbstractLabelsInitializer): """'Generate' the labels from a torch Dataset.""" + def __init__(self, data): self.data, self.targets = parse_data_arg(data) @@ -285,6 +310,7 @@ class DataAwareLabelsInitializer(AbstractLabelsInitializer): class LabelsInitializer(AbstractLabelsInitializer): """Generate labels from `distribution`.""" + def generate(self, distribution: Union[dict, list, tuple]): distribution = parse_distribution(distribution) labels_list = [] @@ -296,6 +322,7 @@ class LabelsInitializer(AbstractLabelsInitializer): class OneHotLabelsInitializer(LabelsInitializer): """Generate one-hot-encoded labels from `distribution`.""" + def generate(self, distribution: Union[dict, list, tuple]): distribution = parse_distribution(distribution) num_classes = len(distribution.keys()) @@ -314,6 +341,7 @@ def compute_distribution_shape(distribution): class AbstractReasoningsInitializer(ABC): """Abstract class for all reasonings initializers.""" + def __init__(self, components_first: bool = True): self.components_first = components_first @@ -334,6 +362,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer): Use this to 'generate' pre-initialized reasonings elsewhere. """ + def __init__(self, reasonings, **kwargs): super().__init__(**kwargs) self.reasonings = reasonings @@ -351,6 +380,7 @@ class LiteralReasoningsInitializer(AbstractReasoningsInitializer): class ZerosReasoningsInitializer(AbstractReasoningsInitializer): """Reasonings are all initialized with zeros.""" + def generate(self, distribution: Union[dict, list, tuple]): shape = compute_distribution_shape(distribution) reasonings = torch.zeros(*shape) @@ -360,6 +390,7 @@ class ZerosReasoningsInitializer(AbstractReasoningsInitializer): class OnesReasoningsInitializer(AbstractReasoningsInitializer): """Reasonings are all initialized with ones.""" + def generate(self, distribution: Union[dict, list, tuple]): shape = compute_distribution_shape(distribution) reasonings = torch.ones(*shape) @@ -369,6 +400,7 @@ class OnesReasoningsInitializer(AbstractReasoningsInitializer): class RandomReasoningsInitializer(AbstractReasoningsInitializer): """Reasonings are randomly initialized.""" + def __init__(self, minimum=0.4, maximum=0.6, **kwargs): super().__init__(**kwargs) self.minimum = minimum @@ -383,6 +415,7 @@ class RandomReasoningsInitializer(AbstractReasoningsInitializer): class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): """Each component reasons positively for exactly one class.""" + def generate(self, distribution: Union[dict, list, tuple]): num_components, num_classes, _ = compute_distribution_shape( distribution) @@ -401,6 +434,7 @@ class AbstractTransformInitializer(ABC): class AbstractLinearTransformInitializer(AbstractTransformInitializer): """Abstract class for all linear transform initializers.""" + def __init__(self, out_dim_first: bool = False): self.out_dim_first = out_dim_first @@ -417,6 +451,7 @@ class AbstractLinearTransformInitializer(AbstractTransformInitializer): class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer): """Initialize a matrix with zeros.""" + def generate(self, in_dim: int, out_dim: int): weights = torch.zeros(in_dim, out_dim) return self.generate_end_hook(weights) @@ -424,6 +459,7 @@ class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer): class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): """Initialize a matrix with ones.""" + def generate(self, in_dim: int, out_dim: int): weights = torch.ones(in_dim, out_dim) return self.generate_end_hook(weights) @@ -431,6 +467,7 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): class EyeTransformInitializer(AbstractLinearTransformInitializer): """Initialize a matrix with the largest possible identity matrix.""" + def generate(self, in_dim: int, out_dim: int): weights = torch.zeros(in_dim, out_dim) I = torch.eye(min(in_dim, out_dim)) @@ -440,6 +477,7 @@ class EyeTransformInitializer(AbstractLinearTransformInitializer): class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): """Abstract class for all data-aware linear transform initializers.""" + def __init__(self, data: torch.Tensor, noise: float = 0.0,