ci: add refurb to pre-commit config
This commit is contained in:
parent
482044ec87
commit
c547af728b
@ -52,3 +52,8 @@ repos:
|
||||
hooks:
|
||||
- id: gitlint
|
||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
||||
|
||||
- repo: https://github.com/dosisod/refurb
|
||||
rev: v1.4.0
|
||||
hooks:
|
||||
- id: refurb
|
||||
|
@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
|
||||
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
||||
"""
|
||||
comparison_fn: Callable = euclidean_distance
|
||||
comparison_args: dict = field(default_factory=lambda: dict())
|
||||
comparison_args: dict = field(default_factory=dict)
|
||||
|
||||
comparison_parameters: dict = field(default_factory=lambda: dict())
|
||||
comparison_parameters: dict = field(default_factory=dict)
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
|
||||
**hparams.comparison_args,
|
||||
)
|
||||
|
||||
self.comparison_kwargs: dict[str, Tensor] = dict()
|
||||
self.comparison_kwargs: dict[str, Tensor] = {}
|
||||
|
||||
def comparison(self, batch, components):
|
||||
comp_tensor, _ = components
|
||||
@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
|
||||
latent_dim: int = 2
|
||||
omega_initializer: type[
|
||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
|
||||
omega_initializer_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
# Steps
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
|
@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
|
||||
lr: The learning rate. Default: 0.1.
|
||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||
"""
|
||||
lr: dict = field(default_factory=lambda: dict())
|
||||
lr: dict = field(default_factory=dict)
|
||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
# Hooks
|
||||
|
@ -41,7 +41,7 @@ class GMLVQ(
|
||||
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
||||
"""
|
||||
comparison_fn: Callable = omega_distance
|
||||
comparison_args: dict = field(default_factory=lambda: dict())
|
||||
comparison_args: dict = field(default_factory=dict)
|
||||
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
||||
|
||||
lr: dict = field(default_factory=lambda: dict(
|
||||
|
5
setup.py
5
setup.py
@ -10,6 +10,8 @@
|
||||
|
||||
ProtoTorch models Plugin Package
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from pkg_resources import safe_name
|
||||
from setuptools import find_namespace_packages, setup
|
||||
|
||||
@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
|
||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
||||
|
||||
with open("README.md", "r") as fh:
|
||||
long_description = fh.read()
|
||||
long_description = Path("README.md").read_text(encoding='utf8')
|
||||
|
||||
INSTALL_REQUIRES = [
|
||||
"prototorch>=0.7.3",
|
||||
|
Loading…
Reference in New Issue
Block a user