[QA] Add more pre commit checks
This commit is contained in:
parent
d0ae94f2af
commit
11cfa79746
@ -9,12 +9,24 @@ repos:
|
|||||||
- id: check-yaml
|
- id: check-yaml
|
||||||
- id: check-added-large-files
|
- id: check-added-large-files
|
||||||
- id: check-ast
|
- id: check-ast
|
||||||
- id: check-byte-order-marker
|
|
||||||
- id: check-case-conflict
|
- id: check-case-conflict
|
||||||
#- repo: https://github.com/pre-commit/mirrors-mypy
|
|
||||||
# rev: 'v0.902'
|
|
||||||
# hooks:
|
- repo: https://github.com/myint/autoflake
|
||||||
# - id: mypy
|
rev: v1.4
|
||||||
|
hooks:
|
||||||
|
- id: autoflake
|
||||||
|
|
||||||
|
- repo: http://github.com/PyCQA/isort
|
||||||
|
rev: 5.8.0
|
||||||
|
hooks:
|
||||||
|
- id: isort
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
|
rev: 'v0.902'
|
||||||
|
hooks:
|
||||||
|
- id: mypy
|
||||||
|
additional_dependencies: [types-pkg_resources]
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-yapf
|
- repo: https://github.com/pre-commit/mirrors-yapf
|
||||||
rev: 'v0.31.0' # Use the sha / tag you want to point at
|
rev: 'v0.31.0' # Use the sha / tag you want to point at
|
||||||
|
@ -68,7 +68,7 @@ master_doc = "index"
|
|||||||
# List of patterns, relative to source directory, that match files and
|
# List of patterns, relative to source directory, that match files and
|
||||||
# directories to ignore when looking for source files.
|
# directories to ignore when looking for source files.
|
||||||
# This pattern also affects html_static_path and html_extra_path.
|
# This pattern also affects html_static_path and html_extra_path.
|
||||||
exclude_patterns = []
|
#exclude_patterns = []
|
||||||
|
|
||||||
# The name of the Pygments (syntax highlighting) style to use. Choose from:
|
# The name of the Pygments (syntax highlighting) style to use. Choose from:
|
||||||
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
|
# ["default", "emacs", "friendly", "colorful", "autumn", "murphy", "manni",
|
||||||
@ -124,20 +124,20 @@ htmlhelp_basename = "protoflowdoc"
|
|||||||
|
|
||||||
# -- Options for LaTeX output ---------------------------------------------
|
# -- Options for LaTeX output ---------------------------------------------
|
||||||
|
|
||||||
latex_elements = {
|
#latex_elements = {
|
||||||
# The paper size ("letterpaper" or "a4paper").
|
# The paper size ("letterpaper" or "a4paper").
|
||||||
#
|
#
|
||||||
# "papersize": "letterpaper",
|
# "papersize": "letterpaper",
|
||||||
# The font size ("10pt", "11pt" or "12pt").
|
# The font size ("10pt", "11pt" or "12pt").
|
||||||
#
|
#
|
||||||
# "pointsize": "10pt",
|
# "pointsize": "10pt",
|
||||||
# Additional stuff for the LaTeX preamble.
|
# Additional stuff for the LaTeX preamble.
|
||||||
#
|
#
|
||||||
# "preamble": "",
|
# "preamble": "",
|
||||||
# Latex figure (float) alignment
|
# Latex figure (float) alignment
|
||||||
#
|
#
|
||||||
# "figure_align": "htbp",
|
# "figure_align": "htbp",
|
||||||
}
|
#}
|
||||||
|
|
||||||
# Grouping the document tree into LaTeX files. List of tuples
|
# Grouping the document tree into LaTeX files. List of tuples
|
||||||
# (source start file, target name, title,
|
# (source start file, target name, title,
|
||||||
|
@ -3,13 +3,14 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
|
from sklearn.datasets import load_iris
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from torchinfo import summary
|
||||||
|
|
||||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
from prototorch.functions.competitions import wtac
|
from prototorch.functions.competitions import wtac
|
||||||
from prototorch.functions.distances import euclidean_distance
|
from prototorch.functions.distances import euclidean_distance
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from sklearn.datasets import load_iris
|
|
||||||
from sklearn.preprocessing import StandardScaler
|
|
||||||
from torchinfo import summary
|
|
||||||
|
|
||||||
# Prepare and preprocess the data
|
# Prepare and preprocess the data
|
||||||
scaler = StandardScaler()
|
scaler = StandardScaler()
|
||||||
|
@ -2,12 +2,13 @@
|
|||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
from prototorch.datasets.tecator import Tecator
|
from prototorch.datasets.tecator import Tecator
|
||||||
from prototorch.functions.distances import sed
|
from prototorch.functions.distances import sed
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.utils.colors import get_legend_handles
|
from prototorch.utils.colors import get_legend_handles
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
# Prepare the dataset and dataloader
|
# Prepare the dataset and dataloader
|
||||||
train_data = Tecator(root="./artifacts", train=True)
|
train_data = Tecator(root="./artifacts", train=True)
|
||||||
|
@ -12,10 +12,11 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchvision
|
import torchvision
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
from prototorch.functions.helper import calculate_prototype_accuracy
|
from prototorch.functions.helper import calculate_prototype_accuracy
|
||||||
from prototorch.modules.losses import GLVQLoss
|
from prototorch.modules.losses import GLVQLoss
|
||||||
from prototorch.modules.models import GTLVQ
|
from prototorch.modules.models import GTLVQ
|
||||||
from torchvision import transforms
|
|
||||||
|
|
||||||
# Parameters and options
|
# Parameters and options
|
||||||
num_epochs = 50
|
num_epochs = 50
|
||||||
|
@ -3,13 +3,14 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
|
||||||
from prototorch.functions.competitions import stratified_min
|
|
||||||
from prototorch.functions.distances import lomega_distance
|
|
||||||
from prototorch.modules.losses import GLVQLoss
|
|
||||||
from sklearn.datasets import load_iris
|
from sklearn.datasets import load_iris
|
||||||
from sklearn.metrics import accuracy_score
|
from sklearn.metrics import accuracy_score
|
||||||
|
|
||||||
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
|
from prototorch.functions.distances import lomega_distance
|
||||||
|
from prototorch.functions.pooling import stratified_min_pooling
|
||||||
|
from prototorch.modules.losses import GLVQLoss
|
||||||
|
|
||||||
# Prepare training data
|
# Prepare training data
|
||||||
x_train, y_train = load_iris(True)
|
x_train, y_train = load_iris(True)
|
||||||
x_train = x_train[:, [0, 2]]
|
x_train = x_train[:, [0, 2]]
|
||||||
@ -55,7 +56,8 @@ for epoch in range(100):
|
|||||||
# Compute loss
|
# Compute loss
|
||||||
dis, plabels = model(x_in)
|
dis, plabels = model(x_in)
|
||||||
loss = criterion([dis, plabels], y_in)
|
loss = criterion([dis, plabels], y_in)
|
||||||
y_pred = np.argmin(stratified_min(dis, plabels).detach().numpy(), axis=1)
|
y_pred = np.argmin(stratified_min_pooling(dis, plabels).detach().numpy(),
|
||||||
|
axis=1)
|
||||||
acc = accuracy_score(y_train, y_pred)
|
acc = accuracy_score(y_train, y_pred)
|
||||||
log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
|
log_string = f"Epoch: {epoch + 1:03d} Loss: {loss.item():05.02f} "
|
||||||
log_string += f"Acc: {acc * 100:05.02f}%"
|
log_string += f"Acc: {acc * 100:05.02f}%"
|
||||||
@ -96,7 +98,8 @@ for epoch in range(100):
|
|||||||
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
mesh_input = np.c_[xx.ravel(), yy.ravel()]
|
||||||
|
|
||||||
d, plabels = model(torch.Tensor(mesh_input))
|
d, plabels = model(torch.Tensor(mesh_input))
|
||||||
y_pred = np.argmin(stratified_min(d, plabels).detach().numpy(), axis=1)
|
y_pred = np.argmin(stratified_min_pooling(d, plabels).detach().numpy(),
|
||||||
|
axis=1)
|
||||||
y_pred = y_pred.reshape(xx.shape)
|
y_pred = y_pred.reshape(xx.shape)
|
||||||
|
|
||||||
# Plot voronoi regions
|
# Plot voronoi regions
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""ProtoTorch package."""
|
"""ProtoTorch package."""
|
||||||
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
|
||||||
@ -19,7 +20,7 @@ __all_core__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
# Plugin Loader
|
# Plugin Loader
|
||||||
__path__ = pkgutil.extend_path(__path__, __name__)
|
__path__: List[str] = pkgutil.extend_path(__path__, __name__)
|
||||||
|
|
||||||
|
|
||||||
def discover_plugins():
|
def discover_plugins():
|
||||||
|
@ -3,12 +3,13 @@
|
|||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from prototorch.components.initializers import (ClassAwareInitializer,
|
from prototorch.components.initializers import (ClassAwareInitializer,
|
||||||
ComponentsInitializer,
|
ComponentsInitializer,
|
||||||
EqualLabelsInitializer,
|
EqualLabelsInitializer,
|
||||||
UnequalLabelsInitializer,
|
UnequalLabelsInitializer,
|
||||||
ZeroReasoningsInitializer)
|
ZeroReasoningsInitializer)
|
||||||
from torch.nn.parameter import Parameter
|
|
||||||
|
|
||||||
from .initializers import parse_data_arg
|
from .initializers import parse_data_arg
|
||||||
|
|
||||||
|
@ -8,11 +8,11 @@ URL:
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Sequence, Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
from prototorch.datasets.abstract import NumpyDataset
|
|
||||||
|
|
||||||
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
from sklearn.datasets import (load_iris, make_blobs, make_circles,
|
||||||
make_classification, make_moons)
|
make_classification, make_moons)
|
||||||
|
|
||||||
|
from prototorch.datasets.abstract import NumpyDataset
|
||||||
|
|
||||||
|
|
||||||
class Iris(NumpyDataset):
|
class Iris(NumpyDataset):
|
||||||
"""Iris Dataset by Ronald Fisher introduced in 1936.
|
"""Iris Dataset by Ronald Fisher introduced in 1936.
|
||||||
|
@ -40,9 +40,10 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from prototorch.datasets.abstract import ProtoDataset
|
|
||||||
from torchvision.datasets.utils import download_file_from_google_drive
|
from torchvision.datasets.utils import download_file_from_google_drive
|
||||||
|
|
||||||
|
from prototorch.datasets.abstract import ProtoDataset
|
||||||
|
|
||||||
|
|
||||||
class Tecator(ProtoDataset):
|
class Tecator(ProtoDataset):
|
||||||
"""
|
"""
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
from prototorch.functions.helper import (_check_shapes, _int_and_mixed_shape,
|
||||||
equal_int_shape, get_flat)
|
equal_int_shape, get_flat)
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""ProtoTorch Competition Modules."""
|
"""ProtoTorch Competition Modules."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.competitions import knnc, wtac
|
from prototorch.functions.competitions import knnc, wtac
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""ProtoTorch losses."""
|
"""ProtoTorch losses."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.activations import get_activation
|
from prototorch.functions.activations import get_activation
|
||||||
from prototorch.functions.losses import glvq_loss
|
from prototorch.functions.losses import glvq_loss
|
||||||
|
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
from prototorch.components import LabeledComponents, StratifiedMeanInitializer
|
||||||
from prototorch.functions.distances import euclidean_distance_matrix
|
from prototorch.functions.distances import euclidean_distance_matrix
|
||||||
from prototorch.functions.normalization import orthogonalization
|
from prototorch.functions.normalization import orthogonalization
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
|
|
||||||
class GTLVQ(nn.Module):
|
class GTLVQ(nn.Module):
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""ProtoTorch Pooling Modules."""
|
"""ProtoTorch Pooling Modules."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions.pooling import (stratified_max_pooling,
|
from prototorch.functions.pooling import (stratified_max_pooling,
|
||||||
stratified_min_pooling,
|
stratified_min_pooling,
|
||||||
stratified_prod_pooling,
|
stratified_prod_pooling,
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
"""ProtoTorch components test suite."""
|
"""ProtoTorch components test suite."""
|
||||||
|
|
||||||
import prototorch as pt
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import prototorch as pt
|
||||||
|
|
||||||
|
|
||||||
def test_labcomps_zeros_init():
|
def test_labcomps_zeros_init():
|
||||||
protos = torch.zeros(3, 2)
|
protos = torch.zeros(3, 2)
|
||||||
|
@ -4,6 +4,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from prototorch.functions import (activations, competitions, distances,
|
from prototorch.functions import (activations, competitions, distances,
|
||||||
initializers, losses, pooling)
|
initializers, losses, pooling)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user