diff --git a/prototorch/core/transforms.py b/prototorch/core/transforms.py
new file mode 100644
index 0000000..3a0ded2
--- /dev/null
+++ b/prototorch/core/transforms.py
@@ -0,0 +1,44 @@
+"""ProtoTorch transforms"""
+
+import torch
+from torch.nn.parameter import Parameter
+
+from .initializers import (
+    AbstractLinearTransformInitializer,
+    EyeTransformInitializer,
+)
+
+
+class LinearTransform(torch.nn.Module):
+    def __init__(
+            self,
+            in_dim: int,
+            out_dim: int,
+            initializer:
+            AbstractLinearTransformInitializer = EyeTransformInitializer(),
+            **kwargs):
+        super().__init__(**kwargs)
+        self.set_weights(in_dim, out_dim, initializer)
+
+    @property
+    def weights(self):
+        return self._weights.detach().cpu()
+
+    def _register_weights(self, weights):
+        self.register_parameter("_weights", Parameter(weights))
+
+    def set_weights(
+            self,
+            in_dim: int,
+            out_dim: int,
+            initializer:
+            AbstractLinearTransformInitializer = EyeTransformInitializer()):
+        weights = initializer.generate(in_dim, out_dim)
+        self._register_weights(weights)
+
+    def forward(self, x):
+        return x @ self._weights.T
+
+
+# Aliases
+Omega = LinearTransform
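
A minimal usage sketch of the new module, outside the diff. It assumes the initializer generates a weight matrix compatible with the transpose in forward (i.e. shaped (out_dim, in_dim)) and that EyeTransformInitializer yields an identity-like matrix; those shape details are assumptions, not stated in the diff.

# Usage sketch for LinearTransform / Omega, under the assumptions above.
import torch

from prototorch.core.transforms import LinearTransform, Omega
from prototorch.core.initializers import EyeTransformInitializer

# Identity-initialized learnable linear map from 4 to 4 dimensions
# (the default initializer is EyeTransformInitializer anyway).
omega = LinearTransform(in_dim=4, out_dim=4,
                        initializer=EyeTransformInitializer())

x = torch.randn(8, 4)   # batch of 8 four-dimensional inputs
y = omega(x)            # apply the transform; shape (8, 4)

# The `weights` property exposes a detached CPU copy for inspection,
# while the trainable Parameter itself is registered as `_weights`.
print(omega.weights.shape)

# `Omega` is simply an alias for LinearTransform.
assert Omega is LinearTransform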