From 47b4b9bcb10764857b4dad98be87cc129867f523 Mon Sep 17 00:00:00 2001 From: Jensun Ravichandran Date: Fri, 7 May 2021 15:18:33 +0200 Subject: [PATCH] Expose prototorch.datasets --- prototorch/__init__.py | 6 ++---- prototorch/datasets/__init__.py | 4 ++++ prototorch/datasets/abstract.py | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/prototorch/__init__.py b/prototorch/__init__.py index c06d32c..3a4f631 100644 --- a/prototorch/__init__.py +++ b/prototorch/__init__.py @@ -1,8 +1,6 @@ """ProtoTorch package.""" -# ############################################# # Core Setup -# ############################################# __version__ = "0.3.0-dev0" __all_core__ = [ @@ -11,9 +9,9 @@ __all_core__ = [ "modules", ] -# ############################################# +from .datasets import * + # Plugin Loader -# ############################################# import pkgutil import pkg_resources diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py index 08a9fa4..051d850 100644 --- a/prototorch/datasets/__init__.py +++ b/prototorch/datasets/__init__.py @@ -1,7 +1,11 @@ """ProtoTorch datasets.""" +from .abstract import NumpyDataset +from .spiral import Spiral from .tecator import Tecator __all__ = [ + "NumpyDataset", + "Spiral", "Tecator", ] diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index 7ff92aa..9d0cf27 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -13,6 +13,7 @@ import torch class NumpyDataset(torch.utils.data.TensorDataset): + """Create a PyTorch TensorDataset from NumPy arrays.""" def __init__(self, *arrays): tensors = [torch.Tensor(arr) for arr in arrays] super().__init__(*tensors)