[QA] Add more pre commit checks

This commit is contained in:
Alexander Engelsberger
2021-06-16 13:46:09 +02:00
committed by Alexander Engelsberger
parent d0ae94f2af
commit 11cfa79746
17 changed files with 66 additions and 38 deletions

View File

@@ -1,6 +1,7 @@
"""ProtoTorch package."""
import pkgutil
from typing import List
import pkg_resources
@@ -19,7 +20,7 @@ __all_core__ = [
]
# Plugin Loader
__path__ = pkgutil.extend_path(__path__, __name__)
__path__: List[str] = pkgutil.extend_path(__path__, __name__)
def discover_plugins():

View File

@@ -3,12 +3,13 @@
import warnings
import torch
from torch.nn.parameter import Parameter
from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer,
EqualLabelsInitializer,
UnequalLabelsInitializer,
ZeroReasoningsInitializer)
from torch.nn.parameter import Parameter
from .initializers import parse_data_arg

View File

@@ -8,11 +8,11 @@ URL:
import warnings
from typing import Sequence, Union
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import (load_iris, make_blobs, make_circles,
make_classification, make_moons)
from prototorch.datasets.abstract import NumpyDataset
class Iris(NumpyDataset):
"""Iris Dataset by Ronald Fisher introduced in 1936.

View File

@@ -40,9 +40,10 @@ import os
import numpy as np
import torch
from prototorch.datasets.abstract import ProtoDataset
from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset):
"""

View File

@@ -2,6 +2,7 @@
import numpy as np
import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
equal_int_shape, get_flat)

View File

@@ -1,6 +1,7 @@
"""ProtoTorch Competition Modules."""
import torch
from prototorch.functions.competitions import knnc, wtac

View File

@@ -1,6 +1,7 @@
"""ProtoTorch losses."""
import torch
from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss

View File

@@ -1,8 +1,9 @@
import torch
from torch import nn
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization
from torch import nn
class GTLVQ(nn.Module):

View File

@@ -1,6 +1,7 @@
"""ProtoTorch Pooling Modules."""
import torch
from prototorch.functions.pooling import (stratified_max_pooling,
stratified_min_pooling,
stratified_prod_pooling,