Expose prototorch.datasets

This commit is contained in:
Jensun Ravichandran 2021-05-07 15:18:33 +02:00
parent 19475d7e2b
commit 47b4b9bcb1
3 changed files with 7 additions and 4 deletions

View File

@ -1,8 +1,6 @@
"""ProtoTorch package.""" """ProtoTorch package."""
# #############################################
# Core Setup # Core Setup
# #############################################
__version__ = "0.3.0-dev0" __version__ = "0.3.0-dev0"
__all_core__ = [ __all_core__ = [
@ -11,9 +9,9 @@ __all_core__ = [
"modules", "modules",
] ]
# ############################################# from .datasets import *
# Plugin Loader # Plugin Loader
# #############################################
import pkgutil import pkgutil
import pkg_resources import pkg_resources

View File

@ -1,7 +1,11 @@
"""ProtoTorch datasets.""" """ProtoTorch datasets."""
from .abstract import NumpyDataset
from .spiral import Spiral
from .tecator import Tecator from .tecator import Tecator
__all__ = [ __all__ = [
"NumpyDataset",
"Spiral",
"Tecator", "Tecator",
] ]

View File

@ -13,6 +13,7 @@ import torch
class NumpyDataset(torch.utils.data.TensorDataset): class NumpyDataset(torch.utils.data.TensorDataset):
"""Create a PyTorch TensorDataset from NumPy arrays."""
def __init__(self, *arrays): def __init__(self, *arrays):
tensors = [torch.Tensor(arr) for arr in arrays] tensors = [torch.Tensor(arr) for arr in arrays]
super().__init__(*tensors) super().__init__(*tensors)