[FEATURE] Add transforms
This commit is contained in:
@@ -243,6 +243,56 @@ def test_pure_positive_reasonings_init_unrepresented_class():
|
||||
assert reasonings.shape[2] == 3
|
||||
|
||||
|
||||
# Transform initializers
|
||||
def test_eye_transform_init_square():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
I = t.generate(3, 3)
|
||||
assert torch.allclose(I, torch.eye(3))
|
||||
|
||||
|
||||
def test_eye_transform_init_narrow():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
actual = t.generate(3, 2)
|
||||
desired = torch.Tensor([[1, 0], [0, 1], [0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_eye_transform_init_wide():
|
||||
t = pt.initializers.EyeTransformInitializer()
|
||||
actual = t.generate(2, 3)
|
||||
desired = torch.Tensor([[1, 0, 0], [0, 1, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
# Transforms
|
||||
def test_linear_transform():
|
||||
l = pt.transforms.LinearTransform(2, 4)
|
||||
actual = l.weights
|
||||
desired = torch.Tensor([[1, 0, 0, 0], [0, 1, 0, 0]])
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_zeros_init():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
out_dim=4,
|
||||
initializer=pt.initializers.ZerosLinearTransformInitializer(),
|
||||
)
|
||||
actual = l.weights
|
||||
desired = torch.zeros(2, 4)
|
||||
assert torch.allclose(actual, desired)
|
||||
|
||||
|
||||
def test_linear_transform_out_dim_first():
|
||||
l = pt.transforms.LinearTransform(
|
||||
in_dim=2,
|
||||
out_dim=4,
|
||||
initializer=pt.initializers.OLTI(out_dim_first=True),
|
||||
)
|
||||
assert l.weights.shape[0] == 4
|
||||
assert l.weights.shape[1] == 2
|
||||
|
||||
|
||||
# Components
|
||||
def test_components_no_initializer():
|
||||
with pytest.raises(TypeError):
|
||||
|
Reference in New Issue
Block a user