Expose prototorch.datasets
This commit is contained in:
parent
19475d7e2b
commit
47b4b9bcb1
@ -1,8 +1,6 @@
|
||||
"""ProtoTorch package."""
|
||||
|
||||
# #############################################
|
||||
# Core Setup
|
||||
# #############################################
|
||||
__version__ = "0.3.0-dev0"
|
||||
|
||||
__all_core__ = [
|
||||
@ -11,9 +9,9 @@ __all_core__ = [
|
||||
"modules",
|
||||
]
|
||||
|
||||
# #############################################
|
||||
from .datasets import *
|
||||
|
||||
# Plugin Loader
|
||||
# #############################################
|
||||
import pkgutil
|
||||
|
||||
import pkg_resources
|
||||
|
@ -1,7 +1,11 @@
|
||||
"""ProtoTorch datasets."""
|
||||
|
||||
from .abstract import NumpyDataset
|
||||
from .spiral import Spiral
|
||||
from .tecator import Tecator
|
||||
|
||||
__all__ = [
|
||||
"NumpyDataset",
|
||||
"Spiral",
|
||||
"Tecator",
|
||||
]
|
||||
|
@ -13,6 +13,7 @@ import torch
|
||||
|
||||
|
||||
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||
def __init__(self, *arrays):
|
||||
tensors = [torch.Tensor(arr) for arr in arrays]
|
||||
super().__init__(*tensors)
|
||||
|
Loading…
Reference in New Issue
Block a user