diff --git a/prototorch/__init__.py b/prototorch/__init__.py index c06d32c..3a4f631 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -1,8 +1,6 @@ """ProtoTorch package.""" -# ############################################# # Core Setup -# ############################################# __version__ = "0.3.0-dev0" __all_core__ = [ @@ -11,9 +9,9 @@ __all_core__ = [ "modules", ] -# ############################################# +from .datasets import * + # Plugin Loader -# ############################################# import pkgutil import pkg_resources diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 08a9fa4..051d850 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -1,7 +1,11 @@ """ProtoTorch datasets.""" +from .abstract import NumpyDataset +from .spiral import Spiral from .tecator import Tecator __all__ = [ + "NumpyDataset", + "Spiral", "Tecator", ] diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index 7ff92aa..9d0cf27 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -13,6 +13,7 @@ import torch class NumpyDataset(torch.utils.data.TensorDataset): + """Create a PyTorch TensorDataset from NumPy arrays.""" def __init__(self, *arrays): tensors = [torch.Tensor(arr) for arr in arrays] super().__init__(*tensors)