[BUGFIX] Add missing file

This commit is contained in:
Jensun Ravichandran 2021-06-16 22:06:33 +02:00
parent 7a6da0c5fc
commit 11cd1b0032

View File

@ -0,0 +1,44 @@
"""ProtoTorch transforms"""
import torch
from torch.nn.parameter import Parameter
from .initializers import (
AbstractLinearTransformInitializer,
EyeTransformInitializer,
)
class LinearTransform(torch.nn.Module):
def __init__(
self,
in_dim: int,
out_dim: int,
initializer:
AbstractLinearTransformInitializer = EyeTransformInitializer(),
**kwargs):
super().__init__(**kwargs)
self.set_weights(in_dim, out_dim, initializer)
@property
def weights(self):
return self._weights.detach().cpu()
def _register_weights(self, weights):
self.register_parameter("_weights", Parameter(weights))
def set_weights(
self,
in_dim: int,
out_dim: int,
initializer:
AbstractLinearTransformInitializer = EyeTransformInitializer()):
weights = initializer.generate(in_dim, out_dim)
self._register_weights(weights)
def forward(self, x):
return x @ self.weights.T
# Aliases
Omega = LinearTransform