ci: add refurb to pre-commit config

This commit is contained in:
Alexander Engelsberger 2022-10-26 13:19:45 +02:00
parent 482044ec87
commit c547af728b
No known key found for this signature in database
5 changed files with 14 additions and 8 deletions

View File

@ -52,3 +52,8 @@ repos:
hooks: hooks:
- id: gitlint - id: gitlint
args: [--contrib=CT1, --ignore=B6, --msg-filename] args: [--contrib=CT1, --ignore=B6, --msg-filename]
- repo: https://github.com/dosisod/refurb
rev: v1.4.0
hooks:
- id: refurb

View File

@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
comparison_args: Keyword arguments for the comparison function. Default: {}. comparison_args: Keyword arguments for the comparison function. Default: {}.
""" """
comparison_fn: Callable = euclidean_distance comparison_fn: Callable = euclidean_distance
comparison_args: dict = field(default_factory=lambda: dict()) comparison_args: dict = field(default_factory=dict)
comparison_parameters: dict = field(default_factory=lambda: dict()) comparison_parameters: dict = field(default_factory=dict)
# Steps # Steps
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------
@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
**hparams.comparison_args, **hparams.comparison_args,
) )
self.comparison_kwargs: dict[str, Tensor] = dict() self.comparison_kwargs: dict[str, Tensor] = {}
def comparison(self, batch, components): def comparison(self, batch, components):
comp_tensor, _ = components comp_tensor, _ = components
@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
latent_dim: int = 2 latent_dim: int = 2
omega_initializer: type[ omega_initializer: type[
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
omega_initializer_kwargs: dict = field(default_factory=lambda: dict()) omega_initializer_kwargs: dict = field(default_factory=dict)
# Steps # Steps
# ---------------------------------------------------------------------------------------------- # ----------------------------------------------------------------------------------------------

View File

@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
lr: The learning rate. Default: 0.1. lr: The learning rate. Default: 0.1.
optimizer: The optimizer to use. Default: torch.optim.Adam. optimizer: The optimizer to use. Default: torch.optim.Adam.
""" """
lr: dict = field(default_factory=lambda: dict()) lr: dict = field(default_factory=dict)
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
# Hooks # Hooks

View File

@ -41,7 +41,7 @@ class GMLVQ(
comparison_args: Keyword arguments for the comparison function. Override Default: {}. comparison_args: Keyword arguments for the comparison function. Override Default: {}.
""" """
comparison_fn: Callable = omega_distance comparison_fn: Callable = omega_distance
comparison_args: dict = field(default_factory=lambda: dict()) comparison_args: dict = field(default_factory=dict)
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
lr: dict = field(default_factory=lambda: dict( lr: dict = field(default_factory=lambda: dict(

View File

@ -10,6 +10,8 @@
ProtoTorch models Plugin Package ProtoTorch models Plugin Package
""" """
from pathlib import Path
from pkg_resources import safe_name from pkg_resources import safe_name
from setuptools import find_namespace_packages, setup from setuptools import find_namespace_packages, setup
@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
PROJECT_URL = "https://github.com/si-cim/prototorch_models" PROJECT_URL = "https://github.com/si-cim/prototorch_models"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git" DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
with open("README.md", "r") as fh: long_description = Path("README.md").read_text(encoding='utf8')
long_description = fh.read()
INSTALL_REQUIRES = [ INSTALL_REQUIRES = [
"prototorch>=0.7.3", "prototorch>=0.7.3",