[FEATURE] Add transforms
This commit is contained in:
parent
c95f91cc29
commit
7a6da0c5fc
@ -14,9 +14,10 @@ from .core import (
|
||||
components,
|
||||
distances,
|
||||
initializers,
|
||||
similarities,
|
||||
losses,
|
||||
pooling,
|
||||
similarities,
|
||||
transforms,
|
||||
)
|
||||
|
||||
# Core Setup
|
||||
@ -33,6 +34,7 @@ __all_core__ = [
|
||||
"nn",
|
||||
"pooling",
|
||||
"similarities",
|
||||
"transforms",
|
||||
"utils",
|
||||
]
|
||||
|
||||
|
@ -7,3 +7,4 @@ from .initializers import *
|
||||
from .losses import *
|
||||
from .pooling import *
|
||||
from .similarities import *
|
||||
from .transforms import *
|
||||
|
@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC):
|
||||
@abstractmethod
|
||||
def generate(self, distribution: Union[dict, list, tuple]):
|
||||
...
|
||||
return generate_end_hook(...)
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
|
||||
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
||||
return reasonings
|
||||
|
||||
|
||||
# Transforms
|
||||
class AbstractTransformInitializer(ABC):
|
||||
"""Abstract class for all transform initializers."""
|
||||
...
|
||||
|
||||
|
||||
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||
"""Abstract class for all linear transform initializers."""
|
||||
def __init__(self, out_dim_first: bool = False):
|
||||
self.out_dim_first = out_dim_first
|
||||
|
||||
def generate_end_hook(self, weights):
|
||||
if self.out_dim_first:
|
||||
weights = weights.permute(1, 0)
|
||||
return weights
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
...
|
||||
return self.generate_end_hook(...)
|
||||
|
||||
|
||||
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with zeros."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with ones."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.ones(in_dim, out_dim)
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
||||
"""Initialize a matrix with the largest possible identity matrix."""
|
||||
def generate(self, in_dim: int, out_dim: int):
|
||||
weights = torch.zeros(in_dim, out_dim)
|
||||
I = torch.eye(min(in_dim, out_dim))
|
||||
weights[:I.shape[0], :I.shape[1]] = I
|
||||
return self.generate_end_hook(weights)
|
||||
|
||||
|
||||
# Aliases - Components
|
||||
CACI = ClassAwareCompInitializer
|
||||
DACI = DataAwareCompInitializer
|
||||
@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer
|
||||
PPRI = PurePositiveReasoningsInitializer
|
||||
RRI = RandomReasoningsInitializer
|
||||
ZRI = ZerosReasoningsInitializer
|
||||
|
||||
# Aliases - Transforms
|
||||
Eye = EyeTransformInitializer
|
||||
OLTI = OnesLinearTransformInitializer
|
||||
ZLTI = ZerosLinearTransformInitializer
|
||||
|
@ -243,6 +243,56 @@ def test_pure_positive_reasonings_init_unrepresented_class():
|
||||
assert reasonings.shape[2] == 3
|
||||
|
||||
|
||||
# Transform initializers
|
||||
def test_eye_transform_init_square():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
I = t.generate(3, 3)
|
||||
assert torch.allclose(I, torch.eye(3))
|
||||
|
||||
|
||||
def test_eye_transform_init_narrow():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
actual = t.generate(3, 2)
|
||||
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_eye_transform_init_wide():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
actual = t.generate(2, 3)
|
||||
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
# Transforms
|
||||
def test_linear_transform():
|
||||
l = pt.transforms.LinearTransform(2, 4)
|
||||
actual = l.weights
|
||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_zeros_init():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
out_dim=4,
|
||||
initializer=pt.initializers.ZerosLinearTransformInitializer(),
|
||||
)
|
||||
actual = l.weights
|
||||
desired = torch.zeros(2, 4)
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_out_dim_first():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
out_dim=4,
|
||||
initializer=pt.initializers.OLTI(out_dim_first=True),
|
||||
)
|
||||
assert l.weights.shape[0] == 4
|
||||
assert l.weights.shape[1] == 2
|
||||
|
||||
|
||||
# Components
|
||||
def test_components_no_initializer():
|
||||
with pytest.raises(TypeError):
|
||||
|
Loading…
Reference in New Issue
Block a user