[FEATURE] Add transforms
This commit is contained in:
		@@ -14,9 +14,10 @@ from .core import (
 | 
				
			|||||||
    components,
 | 
					    components,
 | 
				
			||||||
    distances,
 | 
					    distances,
 | 
				
			||||||
    initializers,
 | 
					    initializers,
 | 
				
			||||||
    similarities,
 | 
					 | 
				
			||||||
    losses,
 | 
					    losses,
 | 
				
			||||||
    pooling,
 | 
					    pooling,
 | 
				
			||||||
 | 
					    similarities,
 | 
				
			||||||
 | 
					    transforms,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Core Setup
 | 
					# Core Setup
 | 
				
			||||||
@@ -33,6 +34,7 @@ __all_core__ = [
 | 
				
			|||||||
    "nn",
 | 
					    "nn",
 | 
				
			||||||
    "pooling",
 | 
					    "pooling",
 | 
				
			||||||
    "similarities",
 | 
					    "similarities",
 | 
				
			||||||
 | 
					    "transforms",
 | 
				
			||||||
    "utils",
 | 
					    "utils",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,3 +7,4 @@ from .initializers import *
 | 
				
			|||||||
from .losses import *
 | 
					from .losses import *
 | 
				
			||||||
from .pooling import *
 | 
					from .pooling import *
 | 
				
			||||||
from .similarities import *
 | 
					from .similarities import *
 | 
				
			||||||
 | 
					from .transforms import *
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -313,7 +313,7 @@ class AbstractReasoningsInitializer(ABC):
 | 
				
			|||||||
    @abstractmethod
 | 
					    @abstractmethod
 | 
				
			||||||
    def generate(self, distribution: Union[dict, list, tuple]):
 | 
					    def generate(self, distribution: Union[dict, list, tuple]):
 | 
				
			||||||
        ...
 | 
					        ...
 | 
				
			||||||
        return generate_end_hook(...)
 | 
					        return self.generate_end_hook(...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
 | 
					class LiteralReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			||||||
@@ -380,6 +380,51 @@ class PurePositiveReasoningsInitializer(AbstractReasoningsInitializer):
 | 
				
			|||||||
        return reasonings
 | 
					        return reasonings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Transforms
 | 
				
			||||||
 | 
					class AbstractTransformInitializer(ABC):
 | 
				
			||||||
 | 
					    """Abstract class for all transform initializers."""
 | 
				
			||||||
 | 
					    ...
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class AbstractLinearTransformInitializer(AbstractTransformInitializer):
 | 
				
			||||||
 | 
					    """Abstract class for all linear transform initializers."""
 | 
				
			||||||
 | 
					    def __init__(self, out_dim_first: bool = False):
 | 
				
			||||||
 | 
					        self.out_dim_first = out_dim_first
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def generate_end_hook(self, weights):
 | 
				
			||||||
 | 
					        if self.out_dim_first:
 | 
				
			||||||
 | 
					            weights = weights.permute(1, 0)
 | 
				
			||||||
 | 
					        return weights
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @abstractmethod
 | 
				
			||||||
 | 
					    def generate(self, in_dim: int, out_dim: int):
 | 
				
			||||||
 | 
					        ...
 | 
				
			||||||
 | 
					        return self.generate_end_hook(...)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ZerosLinearTransformInitializer(AbstractLinearTransformInitializer):
 | 
				
			||||||
 | 
					    """Initialize a matrix with zeros."""
 | 
				
			||||||
 | 
					    def generate(self, in_dim: int, out_dim: int):
 | 
				
			||||||
 | 
					        weights = torch.zeros(in_dim, out_dim)
 | 
				
			||||||
 | 
					        return self.generate_end_hook(weights)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
 | 
				
			||||||
 | 
					    """Initialize a matrix with ones."""
 | 
				
			||||||
 | 
					    def generate(self, in_dim: int, out_dim: int):
 | 
				
			||||||
 | 
					        weights = torch.ones(in_dim, out_dim)
 | 
				
			||||||
 | 
					        return self.generate_end_hook(weights)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class EyeTransformInitializer(AbstractLinearTransformInitializer):
 | 
				
			||||||
 | 
					    """Initialize a matrix with the largest possible identity matrix."""
 | 
				
			||||||
 | 
					    def generate(self, in_dim: int, out_dim: int):
 | 
				
			||||||
 | 
					        weights = torch.zeros(in_dim, out_dim)
 | 
				
			||||||
 | 
					        I = torch.eye(min(in_dim, out_dim))
 | 
				
			||||||
 | 
					        weights[:I.shape[0], :I.shape[1]] = I
 | 
				
			||||||
 | 
					        return self.generate_end_hook(weights)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Aliases - Components
 | 
					# Aliases - Components
 | 
				
			||||||
CACI = ClassAwareCompInitializer
 | 
					CACI = ClassAwareCompInitializer
 | 
				
			||||||
DACI = DataAwareCompInitializer
 | 
					DACI = DataAwareCompInitializer
 | 
				
			||||||
@@ -406,3 +451,8 @@ ORI = OnesReasoningsInitializer
 | 
				
			|||||||
PPRI = PurePositiveReasoningsInitializer
 | 
					PPRI = PurePositiveReasoningsInitializer
 | 
				
			||||||
RRI = RandomReasoningsInitializer
 | 
					RRI = RandomReasoningsInitializer
 | 
				
			||||||
ZRI = ZerosReasoningsInitializer
 | 
					ZRI = ZerosReasoningsInitializer
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Aliases - Transforms
 | 
				
			||||||
 | 
					Eye = EyeTransformInitializer
 | 
				
			||||||
 | 
					OLTI = OnesLinearTransformInitializer
 | 
				
			||||||
 | 
					ZLTI = ZerosLinearTransformInitializer
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -243,6 +243,56 @@ def test_pure_positive_reasonings_init_unrepresented_class():
 | 
				
			|||||||
    assert reasonings.shape[2] == 3
 | 
					    assert reasonings.shape[2] == 3
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Transform initializers
 | 
				
			||||||
 | 
					def test_eye_transform_init_square():
 | 
				
			||||||
 | 
					    t = pt.initializers.EyeTransformInitializer()
 | 
				
			||||||
 | 
					    I = t.generate(3, 3)
 | 
				
			||||||
 | 
					    assert torch.allclose(I, torch.eye(3))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_eye_transform_init_narrow():
 | 
				
			||||||
 | 
					    t = pt.initializers.EyeTransformInitializer()
 | 
				
			||||||
 | 
					    actual = t.generate(3, 2)
 | 
				
			||||||
 | 
					    desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
 | 
				
			||||||
 | 
					    assert torch.allclose(actual, desired)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_eye_transform_init_wide():
 | 
				
			||||||
 | 
					    t = pt.initializers.EyeTransformInitializer()
 | 
				
			||||||
 | 
					    actual = t.generate(2, 3)
 | 
				
			||||||
 | 
					    desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
 | 
				
			||||||
 | 
					    assert torch.allclose(actual, desired)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Transforms
 | 
				
			||||||
 | 
					def test_linear_transform():
 | 
				
			||||||
 | 
					    l = pt.transforms.LinearTransform(2, 4)
 | 
				
			||||||
 | 
					    actual = l.weights
 | 
				
			||||||
 | 
					    desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
 | 
				
			||||||
 | 
					    assert torch.allclose(actual, desired)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_linear_transform_zeros_init():
 | 
				
			||||||
 | 
					    l = pt.transforms.LinearTransform(
 | 
				
			||||||
 | 
					        in_dim=2,
 | 
				
			||||||
 | 
					        out_dim=4,
 | 
				
			||||||
 | 
					        initializer=pt.initializers.ZerosLinearTransformInitializer(),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    actual = l.weights
 | 
				
			||||||
 | 
					    desired = torch.zeros(2, 4)
 | 
				
			||||||
 | 
					    assert torch.allclose(actual, desired)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_linear_transform_out_dim_first():
 | 
				
			||||||
 | 
					    l = pt.transforms.LinearTransform(
 | 
				
			||||||
 | 
					        in_dim=2,
 | 
				
			||||||
 | 
					        out_dim=4,
 | 
				
			||||||
 | 
					        initializer=pt.initializers.OLTI(out_dim_first=True),
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    assert l.weights.shape[0] == 4
 | 
				
			||||||
 | 
					    assert l.weights.shape[1] == 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Components
 | 
					# Components
 | 
				
			||||||
def test_components_no_initializer():
 | 
					def test_components_no_initializer():
 | 
				
			||||||
    with pytest.raises(TypeError):
 | 
					    with pytest.raises(TypeError):
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user