From 7a6da0c5fcda6d4bdd7110efb33442a809265e94 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Wed, 16 Jun 2021 21:53:36 +0200 Subject: [PATCH] [FEATURE] Add transforms --- prototorch/__init__.py | 4 ++- prototorch/core/__init__.py | 1 + prototorch/core/initializers.py | 52 ++++++++++++++++++++++++++++++++- tests/test_core.py | 50 +++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 2 deletions(-) diff --git a/prototorch/__init__.py b/prototorch/__init__.py index d0ce2d6..3412d9c 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -14,9 +14,10 @@ from .core import ( components, distances, initializers, - similarities, losses, pooling, + similarities, + transforms, ) # Core Setup @@ -33,6 +34,7 @@ __all_core__ = [ "nn", "pooling", "similarities", + "transforms", "utils", ] diff --git a/prototorch/core/__init__.py b/prototorch/core/__init__.py index c205dfa..e5961c1 100644 --- a/prototorch/core/__init__.py +++ b/prototorch/core/__init__.py @@ -7,3 +7,4 @@ from .initializers import * from .losses import * from .pooling import * from .similarities import * +from .transforms import * diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index 6a7067b..f5d2743 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC): @abstractmethod def generate(self, distribution: Union[dict, list, tuple]): ... - return generate_end_hook(...) + return self.generate_end_hook(...) class LiteralReasoningsInitializer(AbstractReasoningsInitializer): @@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer): return reasonings +# Transforms +class AbstractTransformInitializer(ABC): + """Abstract class for all transform initializers.""" + ... + + +class AbstractLinearTransformInitializer(AbstractTransformInitializer): + """Abstract class for all linear transform initializers.""" + def __init__(self, out_dim_first: bool = False): + self.out_dim_first = out_dim_first + + def generate_end_hook(self, weights): + if self.out_dim_first: + weights = weights.permute(1, 0) + return weights + + @abstractmethod + def generate(self, in_dim: int, out_dim: int): + ... + return self.generate_end_hook(...) + + +class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with zeros.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.zeros(in_dim, out_dim) + return self.generate_end_hook(weights) + + +class OnesLinearTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with ones.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.ones(in_dim, out_dim) + return self.generate_end_hook(weights) + + +class EyeTransformInitializer(AbstractLinearTransformInitializer): + """Initialize a matrix with the largest possible identity matrix.""" + def generate(self, in_dim: int, out_dim: int): + weights = torch.zeros(in_dim, out_dim) + I = torch.eye(min(in_dim, out_dim)) + weights[:I.shape[0], :I.shape[1]] = I + return self.generate_end_hook(weights) + + # Aliases - Components CACI = ClassAwareCompInitializer DACI = DataAwareCompInitializer @@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer PPRI = PurePositiveReasoningsInitializer RRI = RandomReasoningsInitializer ZRI = ZerosReasoningsInitializer + +# Aliases - Transforms +Eye = EyeTransformInitializer +OLTI = OnesLinearTransformInitializer +ZLTI = ZerosLinearTransformInitializer diff --git a/tests/test_core.py b/tests/test_core.py index 6fad03f..f949037 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -243,6 +243,56 @@ def test_pure_positive_reasonings_init_unrepresented_class(): assert reasonings.shape[2] == 3 +# Transform initializers +def test_eye_transform_init_square(): + t = pt.initializers.EyeTransformInitializer() + I = t.generate(3, 3) + assert torch.allclose(I, torch.eye(3)) + + +def test_eye_transform_init_narrow(): + t = pt.initializers.EyeTransformInitializer() + actual = t.generate(3, 2) + desired = torch.Tensor([[1, 0], [0, 1], [0, 0]]) + assert torch.allclose(actual, desired) + + +def test_eye_transform_init_wide(): + t = pt.initializers.EyeTransformInitializer() + actual = t.generate(2, 3) + desired = torch.Tensor([[1, 0, 0], [0, 1, 0]]) + assert torch.allclose(actual, desired) + + +# Transforms +def test_linear_transform(): + l = pt.transforms.LinearTransform(2, 4) + actual = l.weights + desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]]) + assert torch.allclose(actual, desired) + + +def test_linear_transform_zeros_init(): + l = pt.transforms.LinearTransform( + in_dim=2, + out_dim=4, + initializer=pt.initializers.ZerosLinearTransformInitializer(), + ) + actual = l.weights + desired = torch.zeros(2, 4) + assert torch.allclose(actual, desired) + + +def test_linear_transform_out_dim_first(): + l = pt.transforms.LinearTransform( + in_dim=2, + out_dim=4, + initializer=pt.initializers.OLTI(out_dim_first=True), + ) + assert l.weights.shape[0] == 4 + assert l.weights.shape[1] == 2 + + # Components def test_components_no_initializer(): with pytest.raises(TypeError):