Add plugin loader.

This commit is contained in:
Alexander Engelsberger 2021-04-13 12:36:22 +02:00
parent 429570323e
commit 5b2ab34232
2 changed files with 112 additions and 43 deletions

View File

@ -1,11 +1,43 @@
"""ProtoTorch package."""
__version__ = '0.1.1-rc0'
# #############################################
# Core Setup
# #############################################
__version_core__ = "0.2.0-dev0"
from prototorch import datasets, functions, modules
__all__ = [
'datasets',
'functions',
'modules',
__all_core__ = [
"datasets",
"functions",
"modules",
]
# #############################################
# Plugin Loader
# #############################################
import pkg_resources
def discover_plugins():
return {
entry_point.name: entry_point.load()
for entry_point in pkg_resources.iter_entry_points("prototorch.plugins")
}
discovered_plugins = discover_plugins()
locals().update(discovered_plugins)
# Generate combines __version__ and __all__
__version_plugins__ = "\n".join(
[
"- " + name + ": v" + plugin.__version__
for name, plugin in discovered_plugins.items()
]
)
if __version_plugins__ != "":
__version_plugins__ = "\nPlugins: \n" + __version_plugins__
__version__ = "core: v" + __version_core__ + __version_plugins__
__all__ = __all_core__ + list(discovered_plugins.keys())

View File

@ -1,8 +1,42 @@
"""Install ProtoTorch."""
"""
_____ _ _______ _
| __ \ | | |__ __| | |
| |__) | __ ___ | |_ ___ | | ___ _ __ ___| |__
| ___/ '__/ _ \| __/ _ \| |/ _ \| '__/ __| '_ \
| | | | | (_) | || (_) | | (_) | | | (__| | | |
|_| |_| \___/ \__\___/|_|\___/|_| \___|_| |_|
ProtoTorch Core Package
"""
from setuptools import setup
from setuptools import find_packages
from pkg_resources import safe_name
import ast
import importlib.util
PKG_DIR = "prototorch"
def find_version():
"""Return value of __version__.
Reference: https://stackoverflow.com/a/42269185/
"""
file_path = importlib.util.find_spec(PKG_DIR).origin
with open(file_path) as file_obj:
root_node = ast.parse(file_obj.read())
for node in ast.walk(root_node):
if isinstance(node, ast.Assign):
if len(node.targets) == 1 and node.targets[0].id == "__version_core__":
return node.value.s
raise RuntimeError("Unable to find version string.")
version = find_version()
PROJECT_URL = "https://github.com/si-cim/prototorch"
DOWNLOAD_URL = "https://github.com/si-cim/prototorch.git"
@ -32,8 +66,9 @@ EXAMPLES = [
TESTS = ["pytest"]
ALL = DOCS + DATASETS + EXAMPLES + TESTS
setup(name="prototorch",
version="0.1.1-rc0",
setup(
name=safe_name(PKG_DIR),
version=version,
description="Highly extensible, GPU-supported "
"Learning Vector Quantization (LVQ) toolbox "
"built using PyTorch and its nn API.",
@ -68,4 +103,6 @@ setup(name="prototorch",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
packages=find_packages())
packages=find_packages(),
zip_safe=False,
)