feat: add RandomLinearTransformInitializer

This commit is contained in:
Jensun Ravichandran 2022-04-04 20:55:03 +02:00
parent 85f75bb28c
commit 71a2e74eff
No known key found for this signature in database
GPG Key ID: 7612C0CAB643D921

View File

@ -465,7 +465,15 @@ class OnesLinearTransformInitializer(AbstractLinearTransformInitializer):
return self.generate_end_hook(weights) return self.generate_end_hook(weights)
class EyeTransformInitializer(AbstractLinearTransformInitializer): class RandomLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with random values."""
def generate(self, in_dim: int, out_dim: int):
weights = torch.rand(in_dim, out_dim)
return self.generate_end_hook(weights)
class EyeLinearTransformInitializer(AbstractLinearTransformInitializer):
"""Initialize a matrix with the largest possible identity matrix.""" """Initialize a matrix with the largest possible identity matrix."""
def generate(self, in_dim: int, out_dim: int): def generate(self, in_dim: int, out_dim: int):
@ -539,8 +547,9 @@ RRI = RandomReasoningsInitializer
ZRI = ZerosReasoningsInitializer ZRI = ZerosReasoningsInitializer
# Aliases - Transforms # Aliases - Transforms
Eye = EyeTransformInitializer ELTI = Eye = EyeLinearTransformInitializer
OLTI = OnesLinearTransformInitializer OLTI = OnesLinearTransformInitializer
RLTI = RandomLinearTransformInitializer
ZLTI = ZerosLinearTransformInitializer ZLTI = ZerosLinearTransformInitializer
PCALTI = PCALinearTransformInitializer PCALTI = PCALinearTransformInitializer
LLTI = LiteralLinearTransformInitializer LLTI = LiteralLinearTransformInitializer