feat: add RandomLinearTransformInitializer

This commit is contained in:
Jensun Ravichandran 2022-04-04 20:55:03 +02:00
parent 85f75bb28c
commit 71a2e74eff
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921

View File

@ -465,7 +465,15 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
return self.generate_end_hook(weights)
class EyeTransformInitializer(AbstractLinearTransformInitializer):
class RandomLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with random values."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.rand(in_dim, out_dim)
return self.generate_end_hook(weights)
class EyeLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with the largest possible identity matrix."""
def generate(self, in_dim: int, out_dim: int):
@ -539,8 +547,9 @@ RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer
# Aliases - Transforms
Eye = EyeTransformInitializer
ELTI = Eye = EyeLinearTransformInitializer
OLTI = OnesLinearTransformInitializer
RLTI = RandomLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer
PCALTI = PCALinearTransformInitializer
LLTI = LiteralLinearTransformInitializer