[FEATURE] Add transforms

This commit is contained in:
Jensun Ravichandran 2021-06-16 21:53:36 +02:00
parent c95f91cc29
commit 7a6da0c5fc
4 changed files with 105 additions and 2 deletions

View File

@ -14,9 +14,10 @@ from .core import (
components,
distances,
initializers,
similarities,
losses,
pooling,
similarities,
transforms,
)
# Core Setup
@ -33,6 +34,7 @@ __all_core__ = [
"nn",
"pooling",
"similarities",
"transforms",
"utils",
]

View File

@ -7,3 +7,4 @@ from .initializers import *
from .losses import *
from .pooling import *
from .similarities import *
from .transforms import *

View File

@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC):
@abstractmethod
def generate(self, distribution: Union[dict, list, tuple]):
...
return generate_end_hook(...)
return self.generate_end_hook(...)
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
return reasonings
# Transforms
class AbstractTransformInitializer(ABC):
"""Abstract class for all transform initializers."""
...
class AbstractLinearTransformInitializer(AbstractTransformInitializer):
"""Abstract class for all linear transform initializers."""
def __init__(self, out_dim_first: bool = False):
self.out_dim_first = out_dim_first
def generate_end_hook(self, weights):
if self.out_dim_first:
weights = weights.permute(1, 0)
return weights
@abstractmethod
def generate(self, in_dim: int, out_dim: int):
...
return self.generate_end_hook(...)
class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with zeros."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim)
return self.generate_end_hook(weights)
class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with ones."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.ones(in_dim, out_dim)
return self.generate_end_hook(weights)
class EyeTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with the largest possible identity matrix."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.zeros(in_dim, out_dim)
I = torch.eye(min(in_dim, out_dim))
weights[:I.shape[0], :I.shape[1]] = I
return self.generate_end_hook(weights)
# Aliases - Components
CACI = ClassAwareCompInitializer
DACI = DataAwareCompInitializer
@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer
PPRI = PurePositiveReasoningsInitializer
RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer
# Aliases - Transforms
Eye = EyeTransformInitializer
OLTI = OnesLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer

View File

@ -243,6 +243,56 @@ def test_pure_positive_reasonings_init_unrepresented_class():
assert reasonings.shape[2] == 3
# Transform initializers
def test_eye_transform_init_square():
t = pt.initializers.EyeTransformInitializer()
I = t.generate(3, 3)
assert torch.allclose(I, torch.eye(3))
def test_eye_transform_init_narrow():
t = pt.initializers.EyeTransformInitializer()
actual = t.generate(3, 2)
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
assert torch.allclose(actual, desired)
def test_eye_transform_init_wide():
t = pt.initializers.EyeTransformInitializer()
actual = t.generate(2, 3)
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
assert torch.allclose(actual, desired)
# Transforms
def test_linear_transform():
l = pt.transforms.LinearTransform(2, 4)
actual = l.weights
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
assert torch.allclose(actual, desired)
def test_linear_transform_zeros_init():
l = pt.transforms.LinearTransform(
in_dim=2,
out_dim=4,
initializer=pt.initializers.ZerosLinearTransformInitializer(),
)
actual = l.weights
desired = torch.zeros(2, 4)
assert torch.allclose(actual, desired)
def test_linear_transform_out_dim_first():
l = pt.transforms.LinearTransform(
in_dim=2,
out_dim=4,
initializer=pt.initializers.OLTI(out_dim_first=True),
)
assert l.weights.shape[0] == 4
assert l.weights.shape[1] == 2
# Components
def test_components_no_initializer():
with pytest.raises(TypeError):