ci: add refurb to pre-commit config
This commit is contained in:
parent
482044ec87
commit
c547af728b
@ -52,3 +52,8 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: gitlint
|
- id: gitlint
|
||||||
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
args: [--contrib=CT1, --ignore=B6, --msg-filename]
|
||||||
|
|
||||||
|
- repo: https://github.com/dosisod/refurb
|
||||||
|
rev: v1.4.0
|
||||||
|
hooks:
|
||||||
|
- id: refurb
|
||||||
|
@ -32,9 +32,9 @@ class SimpleComparisonMixin(BaseYArchitecture):
|
|||||||
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
comparison_args: Keyword arguments for the comparison function. Default: {}.
|
||||||
"""
|
"""
|
||||||
comparison_fn: Callable = euclidean_distance
|
comparison_fn: Callable = euclidean_distance
|
||||||
comparison_args: dict = field(default_factory=lambda: dict())
|
comparison_args: dict = field(default_factory=dict)
|
||||||
|
|
||||||
comparison_parameters: dict = field(default_factory=lambda: dict())
|
comparison_parameters: dict = field(default_factory=dict)
|
||||||
|
|
||||||
# Steps
|
# Steps
|
||||||
# ----------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------
|
||||||
@ -44,7 +44,7 @@ class SimpleComparisonMixin(BaseYArchitecture):
|
|||||||
**hparams.comparison_args,
|
**hparams.comparison_args,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.comparison_kwargs: dict[str, Tensor] = dict()
|
self.comparison_kwargs: dict[str, Tensor] = {}
|
||||||
|
|
||||||
def comparison(self, batch, components):
|
def comparison(self, batch, components):
|
||||||
comp_tensor, _ = components
|
comp_tensor, _ = components
|
||||||
@ -86,7 +86,7 @@ class OmegaComparisonMixin(SimpleComparisonMixin):
|
|||||||
latent_dim: int = 2
|
latent_dim: int = 2
|
||||||
omega_initializer: type[
|
omega_initializer: type[
|
||||||
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
AbstractLinearTransformInitializer] = EyeLinearTransformInitializer
|
||||||
omega_initializer_kwargs: dict = field(default_factory=lambda: dict())
|
omega_initializer_kwargs: dict = field(default_factory=dict)
|
||||||
|
|
||||||
# Steps
|
# Steps
|
||||||
# ----------------------------------------------------------------------------------------------
|
# ----------------------------------------------------------------------------------------------
|
||||||
|
@ -46,7 +46,7 @@ class MultipleLearningRateMixin(BaseYArchitecture):
|
|||||||
lr: The learning rate. Default: 0.1.
|
lr: The learning rate. Default: 0.1.
|
||||||
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
optimizer: The optimizer to use. Default: torch.optim.Adam.
|
||||||
"""
|
"""
|
||||||
lr: dict = field(default_factory=lambda: dict())
|
lr: dict = field(default_factory=dict)
|
||||||
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam
|
||||||
|
|
||||||
# Hooks
|
# Hooks
|
||||||
|
@ -41,7 +41,7 @@ class GMLVQ(
|
|||||||
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
comparison_args: Keyword arguments for the comparison function. Override Default: {}.
|
||||||
"""
|
"""
|
||||||
comparison_fn: Callable = omega_distance
|
comparison_fn: Callable = omega_distance
|
||||||
comparison_args: dict = field(default_factory=lambda: dict())
|
comparison_args: dict = field(default_factory=dict)
|
||||||
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
optimizer: type[torch.optim.Optimizer] = torch.optim.Adam
|
||||||
|
|
||||||
lr: dict = field(default_factory=lambda: dict(
|
lr: dict = field(default_factory=lambda: dict(
|
||||||
|
5
setup.py
5
setup.py
@ -10,6 +10,8 @@
|
|||||||
|
|
||||||
ProtoTorch models Plugin Package
|
ProtoTorch models Plugin Package
|
||||||
"""
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from pkg_resources import safe_name
|
from pkg_resources import safe_name
|
||||||
from setuptools import find_namespace_packages, setup
|
from setuptools import find_namespace_packages, setup
|
||||||
|
|
||||||
@ -18,8 +20,7 @@ PLUGIN_NAME = "models"
|
|||||||
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
PROJECT_URL = "https://github.com/si-cim/prototorch_models"
|
||||||
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
DOWNLOAD_URL = "https://github.com/si-cim/prototorch_models.git"
|
||||||
|
|
||||||
with open("README.md", "r") as fh:
|
long_description = Path("README.md").read_text(encoding='utf8')
|
||||||
long_description = fh.read()
|
|
||||||
|
|
||||||
INSTALL_REQUIRES = [
|
INSTALL_REQUIRES = [
|
||||||
"prototorch>=0.7.3",
|
"prototorch>=0.7.3",
|
||||||
|
Loading…
Reference in New Issue
Block a user