Expose prototorch.datasets
This commit is contained in:
parent
19475d7e2b
commit
47b4b9bcb1
@ -1,8 +1,6 @@
|
|||||||
"""ProtoTorch package."""
|
"""ProtoTorch package."""
|
||||||
|
|
||||||
# #############################################
|
|
||||||
# Core Setup
|
# Core Setup
|
||||||
# #############################################
|
|
||||||
__version__ = "0.3.0-dev0"
|
__version__ = "0.3.0-dev0"
|
||||||
|
|
||||||
__all_core__ = [
|
__all_core__ = [
|
||||||
@ -11,9 +9,9 @@ __all_core__ = [
|
|||||||
"modules",
|
"modules",
|
||||||
]
|
]
|
||||||
|
|
||||||
# #############################################
|
from .datasets import *
|
||||||
|
|
||||||
# Plugin Loader
|
# Plugin Loader
|
||||||
# #############################################
|
|
||||||
import pkgutil
|
import pkgutil
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
"""ProtoTorch datasets."""
|
"""ProtoTorch datasets."""
|
||||||
|
|
||||||
|
from .abstract import NumpyDataset
|
||||||
|
from .spiral import Spiral
|
||||||
from .tecator import Tecator
|
from .tecator import Tecator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"NumpyDataset",
|
||||||
|
"Spiral",
|
||||||
"Tecator",
|
"Tecator",
|
||||||
]
|
]
|
||||||
|
@ -13,6 +13,7 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
class NumpyDataset(torch.utils.data.TensorDataset):
|
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||||
|
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||||
def __init__(self, *arrays):
|
def __init__(self, *arrays):
|
||||||
tensors = [torch.Tensor(arr) for arr in arrays]
|
tensors = [torch.Tensor(arr) for arr in arrays]
|
||||||
super().__init__(*tensors)
|
super().__init__(*tensors)
|
||||||
|
Loading…
Reference in New Issue
Block a user