[BUGFIX] Add missing file
This commit is contained in:
parent
7a6da0c5fc
commit
11cd1b0032
44
prototorch/core/transforms.py
Normal file
44
prototorch/core/transforms.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
"""ProtoTorch transforms"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
|
from .initializers import (
|
||||||
|
AbstractLinearTransformInitializer,
|
||||||
|
EyeTransformInitializer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LinearTransform(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
initializer:
|
||||||
|
AbstractLinearTransformInitializer = EyeTransformInitializer(),
|
||||||
|
**kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.set_weights(in_dim, out_dim, initializer)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weights(self):
|
||||||
|
return self._weights.detach().cpu()
|
||||||
|
|
||||||
|
def _register_weights(self, weights):
|
||||||
|
self.register_parameter("_weights", Parameter(weights))
|
||||||
|
|
||||||
|
def set_weights(
|
||||||
|
self,
|
||||||
|
in_dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
initializer:
|
||||||
|
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
||||||
|
weights = initializer.generate(in_dim, out_dim)
|
||||||
|
self._register_weights(weights)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return x @ self.weights.T
|
||||||
|
|
||||||
|
|
||||||
|
# Aliases
|
||||||
|
Omega = LinearTransform
|
Loading…
Reference in New Issue
Block a user