fix: correct typo
This commit is contained in:
@@ -5,7 +5,7 @@ from torch.nn.parameter import Parameter
|
||||
|
||||
from .initializers import (
|
||||
AbstractLinearTransformInitializer,
|
||||
EyeTransformInitializer,
|
||||
EyeLinearTransformInitializer,
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class LinearTransform(torch.nn.Module):
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
initializer:
|
||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
||||
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
|
||||
super().__init__()
|
||||
self.set_weights(in_dim, out_dim, initializer)
|
||||
|
||||
@@ -31,7 +31,7 @@ class LinearTransform(torch.nn.Module):
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
initializer:
|
||||
AbstractLinearTransformInitializer = EyeTransformInitializer()):
|
||||
AbstractLinearTransformInitializer = EyeLinearTransformInitializer()):
|
||||
weights = initializer.generate(in_dim, out_dim)
|
||||
self._register_weights(weights)
|
||||
|
||||
|
Reference in New Issue
Block a user