[FEATURE] Add transforms
This commit is contained in:
parent
c95f91cc29
commit
7a6da0c5fc
@ -14,9 +14,10 @@ from .core import (
|
|||||||
components,
|
components,
|
||||||
distances,
|
distances,
|
||||||
initializers,
|
initializers,
|
||||||
similarities,
|
|
||||||
losses,
|
losses,
|
||||||
pooling,
|
pooling,
|
||||||
|
similarities,
|
||||||
|
transforms,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
@ -33,6 +34,7 @@ __all_core__ = [
|
|||||||
"nn",
|
"nn",
|
||||||
"pooling",
|
"pooling",
|
||||||
"similarities",
|
"similarities",
|
||||||
|
"transforms",
|
||||||
"utils",
|
"utils",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -7,3 +7,4 @@ from .initializers import *
|
|||||||
from .losses import *
|
from .losses import *
|
||||||
from .pooling import *
|
from .pooling import *
|
||||||
from .similarities import *
|
from .similarities import *
|
||||||
|
from .transforms import *
|
||||||
|
@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def generate(self, distribution: Union[dict, list, tuple]):
|
def generate(self, distribution: Union[dict, list, tuple]):
|
||||||
...
|
...
|
||||||
return generate_end_hook(...)
|
return self.generate_end_hook(...)
|
||||||
|
|
||||||
|
|
||||||
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
|
||||||
@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
|
|||||||
return reasonings
|
return reasonings
|
||||||
|
|
||||||
|
|
||||||
|
# Transforms
|
||||||
|
class AbstractTransformInitializer(ABC):
|
||||||
|
"""Abstract class for all transform initializers."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
|
||||||
|
"""Abstract class for all linear transform initializers."""
|
||||||
|
def __init__(self, out_dim_first: bool = False):
|
||||||
|
self.out_dim_first = out_dim_first
|
||||||
|
|
||||||
|
def generate_end_hook(self, weights):
|
||||||
|
if self.out_dim_first:
|
||||||
|
weights = weights.permute(1, 0)
|
||||||
|
return weights
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
...
|
||||||
|
return self.generate_end_hook(...)
|
||||||
|
|
||||||
|
|
||||||
|
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
|
"""Initialize a matrix with zeros."""
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
weights = torch.zeros(in_dim, out_dim)
|
||||||
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
|
"""Initialize a matrix with ones."""
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
weights = torch.ones(in_dim, out_dim)
|
||||||
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
|
class EyeTransformInitializer(AbstractLinearTransformInitializer):
|
||||||
|
"""Initialize a matrix with the largest possible identity matrix."""
|
||||||
|
def generate(self, in_dim: int, out_dim: int):
|
||||||
|
weights = torch.zeros(in_dim, out_dim)
|
||||||
|
I = torch.eye(min(in_dim, out_dim))
|
||||||
|
weights[:I.shape[0], :I.shape[1]] = I
|
||||||
|
return self.generate_end_hook(weights)
|
||||||
|
|
||||||
|
|
||||||
# Aliases - Components
|
# Aliases - Components
|
||||||
CACI = ClassAwareCompInitializer
|
CACI = ClassAwareCompInitializer
|
||||||
DACI = DataAwareCompInitializer
|
DACI = DataAwareCompInitializer
|
||||||
@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer
|
|||||||
PPRI = PurePositiveReasoningsInitializer
|
PPRI = PurePositiveReasoningsInitializer
|
||||||
RRI = RandomReasoningsInitializer
|
RRI = RandomReasoningsInitializer
|
||||||
ZRI = ZerosReasoningsInitializer
|
ZRI = ZerosReasoningsInitializer
|
||||||
|
|
||||||
|
# Aliases - Transforms
|
||||||
|
Eye = EyeTransformInitializer
|
||||||
|
OLTI = OnesLinearTransformInitializer
|
||||||
|
ZLTI = ZerosLinearTransformInitializer
|
||||||
|
@ -243,6 +243,56 @@ def test_pure_positive_reasonings_init_unrepresented_class():
|
|||||||
assert reasonings.shape[2] == 3
|
assert reasonings.shape[2] == 3
|
||||||
|
|
||||||
|
|
||||||
|
# Transform initializers
|
||||||
|
def test_eye_transform_init_square():
|
||||||
|
t = pt.initializers.EyeTransformInitializer()
|
||||||
|
I = t.generate(3, 3)
|
||||||
|
assert torch.allclose(I, torch.eye(3))
|
||||||
|
|
||||||
|
|
||||||
|
def test_eye_transform_init_narrow():
|
||||||
|
t = pt.initializers.EyeTransformInitializer()
|
||||||
|
actual = t.generate(3, 2)
|
||||||
|
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
||||||
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
def test_eye_transform_init_wide():
|
||||||
|
t = pt.initializers.EyeTransformInitializer()
|
||||||
|
actual = t.generate(2, 3)
|
||||||
|
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
||||||
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
# Transforms
|
||||||
|
def test_linear_transform():
|
||||||
|
l = pt.transforms.LinearTransform(2, 4)
|
||||||
|
actual = l.weights
|
||||||
|
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||||
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_transform_zeros_init():
|
||||||
|
l = pt.transforms.LinearTransform(
|
||||||
|
in_dim=2,
|
||||||
|
out_dim=4,
|
||||||
|
initializer=pt.initializers.ZerosLinearTransformInitializer(),
|
||||||
|
)
|
||||||
|
actual = l.weights
|
||||||
|
desired = torch.zeros(2, 4)
|
||||||
|
assert torch.allclose(actual, desired)
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_transform_out_dim_first():
|
||||||
|
l = pt.transforms.LinearTransform(
|
||||||
|
in_dim=2,
|
||||||
|
out_dim=4,
|
||||||
|
initializer=pt.initializers.OLTI(out_dim_first=True),
|
||||||
|
)
|
||||||
|
assert l.weights.shape[0] == 4
|
||||||
|
assert l.weights.shape[1] == 2
|
||||||
|
|
||||||
|
|
||||||
# Components
|
# Components
|
||||||
def test_components_no_initializer():
|
def test_components_no_initializer():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
|
Loading…
Reference in New Issue
Block a user