From c547af728bea0f8b29c33c92ffbd831258ca2f38 Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Wed, 26 Oct 2022 13:19:45 +0200 Subject: [PATCH] ci: add refurb to pre-commit config --- .pre-commit-config.yaml | 5 +++++ prototorch/models/architectures/comparison.py | 8 ++++---- prototorch/models/architectures/optimization.py | 2 +- prototorch/models/library/gmlvq.py | 2 +- setup.py | 5 +++-- 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 85263e4..df71728 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,3 +52,8 @@ repos: hooks: - id: gitlint args: [--contrib=CT1, --ignore=B6, --msg-filename] + +- repo: https://github.com/dosisod/refurb + rev: v1.4.0 + hooks: + - id: refurb diff --git a/prototorch/models/architectures/comparison.py b/prototorch/models/architectures/comparison.py index db7d898..02f498b 100644 --- a/prototorch/models/architectures/comparison.py +++ b/prototorch/models/architectures/comparison.py @@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture): comparison_args: Keyword arguments for the comparison function. Default: {}. """ comparison_fn: Callable = euclidean_distance - comparison_args: dict = field(default_factory=lambda: dict()) + comparison_args: dict = field(default_factory=dict) - comparison_parameters: dict = field(default_factory=lambda: dict()) + comparison_parameters: dict = field(default_factory=dict) # Steps # ---------------------------------------------------------------------------------------------- @@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture): **hparams.comparison_args, ) - self.comparison_kwargs: dict[str, Tensor] = dict() + self.comparison_kwargs: dict[str, Tensor] = {} def comparison(self, batch, components): comp_tensor, _ = components @@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin): latent_dim: int = 2 omega_initializer: type[ AbstractLinearTransformInitializer] = EyeLinearTransformInitializer - omega_initializer_kwargs: dict = field(default_factory=lambda: dict()) + omega_initializer_kwargs: dict = field(default_factory=dict) # Steps # ---------------------------------------------------------------------------------------------- diff --git a/prototorch/models/architectures/optimization.py b/prototorch/models/architectures/optimization.py index 481809f..a863835 100644 --- a/prototorch/models/architectures/optimization.py +++ b/prototorch/models/architectures/optimization.py @@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture): lr: The learning rate. Default: 0.1. optimizer: The optimizer to use. Default: torch.optim.Adam. """ - lr: dict = field(default_factory=lambda: dict()) + lr: dict = field(default_factory=dict) optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam # Hooks diff --git a/prototorch/models/library/gmlvq.py b/prototorch/models/library/gmlvq.py index 5261f84..5d6e0f5 100644 --- a/prototorch/models/library/gmlvq.py +++ b/prototorch/models/library/gmlvq.py @@ -41,7 +41,7 @@ class GMLVQ( comparison_args: Keyword arguments for the comparison function. Override Default: {}. """ comparison_fn: Callable = omega_distance - comparison_args: dict = field(default_factory=lambda: dict()) + comparison_args: dict = field(default_factory=dict) optimizer: type[torch.optim.Optimizer] = torch.optim.Adam lr: dict = field(default_factory=lambda: dict( diff --git a/setup.py b/setup.py index 373105c..231be7e 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,8 @@ ProtoTorch models Plugin Package """ +from pathlib import Path + from pkg_resources import safe_name from setuptools import find_namespace_packages, setup @@ -18,8 +20,7 @@ PLUGIN_NAME = "models" PROJECT_URL = "https://github.com/si-cim/prototorch_models" DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git" -with open("README.md", "r") as fh: - long_description = fh.read() +long_description = Path("README.md").read_text(encoding='utf8') INSTALL_REQUIRES = [ "prototorch>=0.7.3",