ci: add refurb to pre-commit config

This commit is contained in:
Alexander Engelsberger 2022-10-26 13:19:45 +02:00
parent 482044ec87
commit c547af728b
No known key found for this signature in database
5 changed files with 14 additions and 8 deletions

View File

@ -52,3 +52,8 @@ repos:
hooks:
- id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename]
- repo: https://github.com/dosisod/refurb
rev: v1.4.0
hooks:
- id: refurb

View File

@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
comparison_args: Keyword arguments for the comparison function. Default: {}.
"""
comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_args: dict = field(default_factory=dict)
comparison_parameters: dict = field(default_factory=lambda: dict())
comparison_parameters: dict = field(default_factory=dict)
# Steps
# ----------------------------------------------------------------------------------------------
@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
**hparams.comparison_args,
)
self.comparison_kwargs: dict[str, Tensor] = dict()
self.comparison_kwargs: dict[str, Tensor] = {}
def comparison(self, batch, components):
comp_tensor, _ = components
@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
latent_dim: int = 2
omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
omega_initializer_kwargs: dict = field(default_factory=dict)
# Steps
# ----------------------------------------------------------------------------------------------

View File

@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam.
"""
lr: dict = field(default_factory=lambda: dict())
lr: dict = field(default_factory=dict)
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks

View File

@ -41,7 +41,7 @@ class GMLVQ(
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
"""
comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict())
comparison_args: dict = field(default_factory=dict)
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict(

View File

@ -10,6 +10,8 @@
ProtoTorch models Plugin Package
"""
from pathlib import Path
from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup
@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh:
long_description = fh.read()
long_description = Path("README.md").read_text(encoding='utf8')
INSTALL_REQUIRES = [
"prototorch>=0.7.3",