[QA] Add more pre commit checks

This commit is contained in:
Alexander Engelsberger 2021-06-16 13:46:09 +02:00 committed by Alexander Engelsberger
parent d0ae94f2af
commit 11cfa79746
17 changed files with 66 additions and 38 deletions

View File

@ -9,12 +9,24 @@ repos:
- id: check-yaml - id: check-yaml
- id: check-added-large-files - id: check-added-large-files
- id: check-ast - id: check-ast
- id: check-byte-order-marker
- id: check-case-conflict - id: check-case-conflict
#- repo: https://github.com/pre-commit/mirrors-mypy
# rev: 'v0.902'
# hooks: - repo: https://github.com/myint/autoflake
# - id: mypy rev: v1.4
hooks:
- id: autoflake
- repo: http://github.com/PyCQA/isort
rev: 5.8.0
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.902'
hooks:
- id: mypy
additional_dependencies: [types-pkg_resources]
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/pre-commit/mirrors-yapf
rev: 'v0.31.0' # Use the sha / tag you want to point at rev: 'v0.31.0' # Use the sha / tag you want to point at

View File

@ -68,7 +68,7 @@ master_doc = "index"
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path. # This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [] #exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use. Choose from: # The name of the Pygments (syntax highlighting) style to use. Choose from:
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni", # ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
@ -124,20 +124,20 @@ htmlhelp_basename = "protoflowdoc"
# -- Options for LaTeX output --------------------------------------------- # -- Options for LaTeX output ---------------------------------------------
latex_elements = { #latex_elements = {
# The paper size ("letterpaper" or "a4paper"). # The paper size ("letterpaper" or "a4paper").
# #
# "papersize": "letterpaper", # "papersize": "letterpaper",
# The font size ("10pt", "11pt" or "12pt"). # The font size ("10pt", "11pt" or "12pt").
# #
# "pointsize": "10pt", # "pointsize": "10pt",
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
# #
# "preamble": "", # "preamble": "",
# Latex figure (float) alignment # Latex figure (float) alignment
# #
# "figure_align": "htbp", # "figure_align": "htbp",
} #}
# Grouping the document tree into LaTeX files. List of tuples # Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, # (source start file, target name, title,

View File

@ -3,13 +3,14 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
from prototorch.components import LabeledComponents, StratifiedMeanInitializer from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import wtac from prototorch.functions.competitions import wtac
from prototorch.functions.distances import euclidean_distance from prototorch.functions.distances import euclidean_distance
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
from torchinfo import summary
# Prepare and preprocess the data # Prepare and preprocess the data
scaler = StandardScaler() scaler = StandardScaler()

View File

@ -2,12 +2,13 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import torch import torch
from torch.utils.data import DataLoader
from prototorch.components import LabeledComponents, StratifiedMeanInitializer from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.datasets.tecator import Tecator from prototorch.datasets.tecator import Tecator
from prototorch.functions.distances import sed from prototorch.functions.distances import sed
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.utils.colors import get_legend_handles from prototorch.utils.colors import get_legend_handles
from torch.utils.data import DataLoader
# Prepare the dataset and dataloader # Prepare the dataset and dataloader
train_data = Tecator(root="./artifacts", train=True) train_data = Tecator(root="./artifacts", train=True)

View File

@ -12,10 +12,11 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torchvision import torchvision
from torchvision import transforms
from prototorch.functions.helper import calculate_prototype_accuracy from prototorch.functions.helper import calculate_prototype_accuracy
from prototorch.modules.losses import GLVQLoss from prototorch.modules.losses import GLVQLoss
from prototorch.modules.models import GTLVQ from prototorch.modules.models import GTLVQ
from torchvision import transforms
# Parameters and options # Parameters and options
num_epochs = 50 num_epochs = 50

View File

@ -3,13 +3,14 @@
import numpy as np import numpy as np
import torch import torch
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.competitions import stratified_min
from prototorch.functions.distances import lomega_distance
from prototorch.modules.losses import GLVQLoss
from sklearn.datasets import load_iris from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import lomega_distance
from prototorch.functions.pooling import stratified_min_pooling
from prototorch.modules.losses import GLVQLoss
# Prepare training data # Prepare training data
x_train, y_train = load_iris(True) x_train, y_train = load_iris(True)
x_train = x_train[:, [0, 2]] x_train = x_train[:, [0, 2]]
@ -55,7 +56,8 @@ for epoch in range(100):
# Compute loss # Compute loss
dis, plabels = model(x_in) dis, plabels = model(x_in)
loss = criterion([dis, plabels], y_in) loss = criterion([dis, plabels], y_in)
y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1) y_pred = np.argmin(stratified_min_pooling(dis, plabels).detach().numpy(),
axis=1)
acc = accuracy_score(y_train, y_pred) acc = accuracy_score(y_train, y_pred)
log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} " log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
log_string += f"Acc: {acc * 100:05.02f}%" log_string += f"Acc: {acc * 100:05.02f}%"
@ -96,7 +98,8 @@ for epoch in range(100):
mesh_input = np.c_[xx.ravel(), yy.ravel()] mesh_input = np.c_[xx.ravel(), yy.ravel()]
d, plabels = model(torch.Tensor(mesh_input)) d, plabels = model(torch.Tensor(mesh_input))
y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1) y_pred = np.argmin(stratified_min_pooling(d, plabels).detach().numpy(),
axis=1)
y_pred = y_pred.reshape(xx.shape) y_pred = y_pred.reshape(xx.shape)
# Plot voronoi regions # Plot voronoi regions

View File

@ -1,6 +1,7 @@
"""ProtoTorch package.""" """ProtoTorch package."""
import pkgutil import pkgutil
from typing import List
import pkg_resources import pkg_resources
@ -19,7 +20,7 @@ __all_core__ = [
] ]
# Plugin Loader # Plugin Loader
__path__ = pkgutil.extend_path(__path__, __name__) __path__: List[str] = pkgutil.extend_path(__path__, __name__)
def discover_plugins(): def discover_plugins():

View File

@ -3,12 +3,13 @@
import warnings import warnings
import torch import torch
from torch.nn.parameter import Parameter
from prototorch.components.initializers import (ClassAwareInitializer, from prototorch.components.initializers import (ClassAwareInitializer,
ComponentsInitializer, ComponentsInitializer,
EqualLabelsInitializer, EqualLabelsInitializer,
UnequalLabelsInitializer, UnequalLabelsInitializer,
ZeroReasoningsInitializer) ZeroReasoningsInitializer)
from torch.nn.parameter import Parameter
from .initializers import parse_data_arg from .initializers import parse_data_arg

View File

@ -8,11 +8,11 @@ URL:
import warnings import warnings
from typing import Sequence, Union from typing import Sequence, Union
from prototorch.datasets.abstract import NumpyDataset
from sklearn.datasets import (load_iris, make_blobs, make_circles, from sklearn.datasets import (load_iris, make_blobs, make_circles,
make_classification, make_moons) make_classification, make_moons)
from prototorch.datasets.abstract import NumpyDataset
class Iris(NumpyDataset): class Iris(NumpyDataset):
"""Iris Dataset by Ronald Fisher introduced in 1936. """Iris Dataset by Ronald Fisher introduced in 1936.

View File

@ -40,9 +40,10 @@ import os
import numpy as np import numpy as np
import torch import torch
from prototorch.datasets.abstract import ProtoDataset
from torchvision.datasets.utils import download_file_from_google_drive from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset): class Tecator(ProtoDataset):
""" """

View File

@ -2,6 +2,7 @@
import numpy as np import numpy as np
import torch import torch
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape, from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
equal_int_shape, get_flat) equal_int_shape, get_flat)

View File

@ -1,6 +1,7 @@
"""ProtoTorch Competition Modules.""" """ProtoTorch Competition Modules."""
import torch import torch
from prototorch.functions.competitions import knnc, wtac from prototorch.functions.competitions import knnc, wtac

View File

@ -1,6 +1,7 @@
"""ProtoTorch losses.""" """ProtoTorch losses."""
import torch import torch
from prototorch.functions.activations import get_activation from prototorch.functions.activations import get_activation
from prototorch.functions.losses import glvq_loss from prototorch.functions.losses import glvq_loss

View File

@ -1,8 +1,9 @@
import torch import torch
from torch import nn
from prototorch.components import LabeledComponents, StratifiedMeanInitializer from prototorch.components import LabeledComponents, StratifiedMeanInitializer
from prototorch.functions.distances import euclidean_distance_matrix from prototorch.functions.distances import euclidean_distance_matrix
from prototorch.functions.normalization import orthogonalization from prototorch.functions.normalization import orthogonalization
from torch import nn
class GTLVQ(nn.Module): class GTLVQ(nn.Module):

View File

@ -1,6 +1,7 @@
"""ProtoTorch Pooling Modules.""" """ProtoTorch Pooling Modules."""
import torch import torch
from prototorch.functions.pooling import (stratified_max_pooling, from prototorch.functions.pooling import (stratified_max_pooling,
stratified_min_pooling, stratified_min_pooling,
stratified_prod_pooling, stratified_prod_pooling,

View File

@ -1,8 +1,9 @@
"""ProtoTorch components test suite.""" """ProtoTorch components test suite."""
import prototorch as pt
import torch import torch
import prototorch as pt
def test_labcomps_zeros_init(): def test_labcomps_zeros_init():
protos = torch.zeros(3, 2) protos = torch.zeros(3, 2)

View File

@ -4,6 +4,7 @@ import unittest
import numpy as np import numpy as np
import torch import torch
from prototorch.functions import (activations, competitions, distances, from prototorch.functions import (activations, competitions, distances,
initializers, losses, pooling) initializers, losses, pooling)