"""ProtoTorch transforms."""

import torch
from torch.nn.parameter import Parameter

from .initializers import (
    AbstractLinearTransformInitializer,
    EyeTransformInitializer,
)


class LinearTransform(torch.nn.Module):
    """Learnable linear transform ``x @ W.T``.

    The weight matrix is created by an initializer (identity-like by
    default via ``EyeTransformInitializer``) and registered as a trainable
    ``torch.nn.Parameter``.

    Args:
        in_dim: Input feature dimension.
        out_dim: Output feature dimension.
        initializer: Object with a ``generate(in_dim, out_dim)`` method
            producing the initial weight tensor.
        **kwargs: Forwarded to ``torch.nn.Module``.
    """

    def __init__(
            self,
            in_dim: int,
            out_dim: int,
            initializer:
            AbstractLinearTransformInitializer = EyeTransformInitializer(),
            **kwargs):
        super().__init__(**kwargs)
        self.set_weights(in_dim, out_dim, initializer)

    @property
    def weights(self):
        """Detached CPU copy of the weights, for inspection only.

        NOTE: This is intentionally *not* used inside :meth:`forward`,
        because detaching would break gradient flow.
        """
        return self._weights.detach().cpu()

    def _register_weights(self, weights):
        # Register as a parameter so the optimizer sees the weights and
        # they move with the module across devices.
        self.register_parameter("_weights", Parameter(weights))

    def set_weights(
            self,
            in_dim: int,
            out_dim: int,
            initializer:
            AbstractLinearTransformInitializer = EyeTransformInitializer()):
        """(Re-)initialize and register the weight matrix."""
        weights = initializer.generate(in_dim, out_dim)
        self._register_weights(weights)

    def forward(self, x):
        """Apply the linear transform to a batch of inputs.

        BUGFIX: use the live parameter ``self._weights`` instead of the
        ``weights`` property. The property returns a detached CPU copy,
        which (a) cuts the autograd graph so the weights never train, and
        (b) raises a device-mismatch error for non-CPU inputs.
        """
        return x @ self._weights.T


# Aliases
Omega = LinearTransform