ci: add refurb to pre-commit config

This commit is contained in:
Alexander Engelsberger
2022-10-26 13:19:45 +02:00
parent 482044ec87
commit c547af728b
5 changed files with 14 additions and 8 deletions

View File

@@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_args: dict = field(default_factory=dict)
comparison_parameters: dict = field(default_factory=lambda: dict())
comparison_parameters: dict = field(default_factory=dict)
# Steps
# ----------------------------------------------------------------------------------------------
@@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
self.comparison_kwargs: dict[str, Tensor] = {}
def comparison(self, batch, components):
comp_tensor, _ = components
@@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
omega_initializer_kwargs: dict = field(default_factory=dict)
# Steps
# ----------------------------------------------------------------------------------------------

View File

@@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
lr: dict = field(default_factory=dict)
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks