[BUGFIX] Add missing file
This commit is contained in:
parent
7a6da0c5fc
commit
11cd1b0032
44
prototorch/core/transforms.py
Normal file
44
prototorch/core/transforms.py
Normal file
@ -0,0 +1,44 @@
|
||||
"""ProtoTorch transforms"""
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from .initializers import (
|
||||
AbstractLinearTransformInitializer,
|
||||
EyeTransformInitializer,
|
||||
)
|
||||
|
||||
|
||||
class LinearTransform(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
initializer:
|
||||
AbstractLinearTransformInitializer = EyeTransformInitializer(),
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.set_weights(in_dim, out_dim, initializer)
|
||||
|
||||
@property
|
||||
def weights(self):
|
||||
return self._weights.detach().cpu()
|
||||
|
||||
def _register_weights(self, weights):
|
||||
self.register_parameter("_weights", Parameter(weights))
|
||||
|
||||
def set_weights(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
initializer:
|
||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
||||
weights = initializer.generate(in_dim, out_dim)
|
||||
self._register_weights(weights)
|
||||
|
||||
def forward(self, x):
|
||||
return x @ self.weights.T
|
||||
|
||||
|
||||
# Aliases
|
||||
Omega = LinearTransform
|
Loading…
Reference in New Issue
Block a user