diff --git a/prototorch/components/initializers.py b/prototorch/components/initializers.py
index 4582d6d..7f5f904 100644
--- a/prototorch/components/initializers.py
+++ b/prototorch/components/initializers.py
@@ -1,4 +1,5 @@
-"""ProtoTroch Initializers."""
+"""ProtoTorch Component and Label Initializers."""
+
 import warnings
 from collections.abc import Iterable
 from itertools import chain
diff --git a/prototorch/modules/__init__.py b/prototorch/modules/__init__.py
index fc7ab87..5675094 100644
--- a/prototorch/modules/__init__.py
+++ b/prototorch/modules/__init__.py
@@ -1,5 +1,7 @@
 """ProtoTorch modules."""

 from .competitions import *
+from .initializers import *
 from .pooling import *
+from .transformations import *
 from .wrappers import LambdaLayer, LossLayer
diff --git a/prototorch/modules/initializers.py b/prototorch/modules/initializers.py
new file mode 100644
index 0000000..e31c8f3
--- /dev/null
+++ b/prototorch/modules/initializers.py
@@ -0,0 +1,61 @@
+"""ProtoTorch Module Initializers."""
+
+import torch
+
+
+# Transformations
+class MatrixInitializer(object):
+    def __init__(self, *args, **kwargs):
+        ...
+
+    def generate(self, shape):
+        raise NotImplementedError("Subclasses should implement this!")
+
+
+class ZerosInitializer(MatrixInitializer):
+    def generate(self, shape):
+        return torch.zeros(shape)
+
+
+class OnesInitializer(MatrixInitializer):
+    def __init__(self, scale=1.0):
+        super().__init__()
+        self.scale = scale
+
+    def generate(self, shape):
+        return torch.ones(shape) * self.scale
+
+
+class UniformInitializer(MatrixInitializer):
+    def __init__(self, minimum=0.0, maximum=1.0, scale=1.0):
+        super().__init__()
+        self.minimum = minimum
+        self.maximum = maximum
+        self.scale = scale
+
+    def generate(self, shape):
+        return torch.ones(shape).uniform_(self.minimum,
+                                          self.maximum) * self.scale
+
+
+class DataAwareInitializer(MatrixInitializer):
+    def __init__(self, data, transform=torch.nn.Identity()):
+        super().__init__()
+        self.data = data
+        self.transform = transform
+
+    def __del__(self):
+        del self.data
+
+
+class EigenVectorInitializer(DataAwareInitializer):
+    def generate(self, shape):
+        # TODO
+        raise NotImplementedError()
+
+
+# Aliases
+EV = EigenVectorInitializer
+Random = RandomInitializer = UniformInitializer
+Zeros = ZerosInitializer
+Ones = OnesInitializer
diff --git a/prototorch/modules/transformations.py b/prototorch/modules/transformations.py
new file mode 100644
index 0000000..a2f7e99
--- /dev/null
+++ b/prototorch/modules/transformations.py
@@ -0,0 +1,49 @@
+"""ProtoTorch Transformation Layers."""
+
+import torch
+from torch.nn.parameter import Parameter
+
+from .initializers import MatrixInitializer
+
+
+def _precheck_initializer(initializer):
+    if not isinstance(initializer, MatrixInitializer):
+        emsg = f"`initializer` has to be some subtype of " \
+               f"{MatrixInitializer}. " \
+               f"You have provided: {initializer=} instead."
+        raise TypeError(emsg)
+
+
+class Omega(torch.nn.Module):
+    """The Omega mapping used in GMLVQ."""
+    def __init__(self,
+                 num_replicas=1,
+                 input_dim=None,
+                 latent_dim=None,
+                 initializer=None,
+                 *,
+                 initialized_weights=None):
+        super().__init__()
+
+        if initialized_weights is not None:
+            self._register_weights(initialized_weights)
+        else:
+            if num_replicas == 1:
+                shape = (input_dim, latent_dim)
+            else:
+                shape = (num_replicas, input_dim, latent_dim)
+            self._initialize_weights(shape, initializer)
+
+    def _register_weights(self, weights):
+        self.register_parameter("_omega", Parameter(weights))
+
+    def _initialize_weights(self, shape, initializer):
+        _precheck_initializer(initializer)
+        _omega = initializer.generate(shape)
+        self._register_weights(_omega)
+
+    def forward(self):
+        return self._omega
+
+    def extra_repr(self):
+        return f"(omega): (shape: {tuple(self._omega.shape)})"
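
A minimal usage sketch (not part of the patch), assuming the wildcard imports added to prototorch/modules/__init__.py expose Omega and UniformInitializer once the package is installed:

    import torch
    from prototorch.modules import Omega, UniformInitializer

    # Build an initializer and let Omega generate its (input_dim, latent_dim) weights.
    init = UniformInitializer(minimum=-1.0, maximum=1.0)  # scale defaults to 1.0
    omega_layer = Omega(input_dim=100, latent_dim=2, initializer=init)

    omega = omega_layer()  # forward() simply returns the registered _omega parameter
    assert omega.shape == (100, 2)

Passing initialized_weights instead of an initializer (e.g. a pre-trained matrix) skips generation and registers the given tensor directly as the parameter.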