Expose prototorch.datasets
This commit is contained in:
		@@ -1,8 +1,6 @@
 | 
				
			|||||||
"""ProtoTorch package."""
 | 
					"""ProtoTorch package."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# #############################################
 | 
					 | 
				
			||||||
# Core Setup
 | 
					# Core Setup
 | 
				
			||||||
# #############################################
 | 
					 | 
				
			||||||
__version__ = "0.3.0-dev0"
 | 
					__version__ = "0.3.0-dev0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all_core__ = [
 | 
					__all_core__ = [
 | 
				
			||||||
@@ -11,9 +9,9 @@ __all_core__ = [
 | 
				
			|||||||
    "modules",
 | 
					    "modules",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# #############################################
 | 
					from .datasets import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Plugin Loader
 | 
					# Plugin Loader
 | 
				
			||||||
# #############################################
 | 
					 | 
				
			||||||
import pkgutil
 | 
					import pkgutil
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pkg_resources
 | 
					import pkg_resources
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,11 @@
 | 
				
			|||||||
"""ProtoTorch datasets."""
 | 
					"""ProtoTorch datasets."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from .abstract import NumpyDataset
 | 
				
			||||||
 | 
					from .spiral import Spiral
 | 
				
			||||||
from .tecator import Tecator
 | 
					from .tecator import Tecator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all__ = [
 | 
					__all__ = [
 | 
				
			||||||
 | 
					    "NumpyDataset",
 | 
				
			||||||
 | 
					    "Spiral",
 | 
				
			||||||
    "Tecator",
 | 
					    "Tecator",
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,6 +13,7 @@ import torch
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class NumpyDataset(torch.utils.data.TensorDataset):
 | 
					class NumpyDataset(torch.utils.data.TensorDataset):
 | 
				
			||||||
 | 
					    """Create a PyTorch TensorDataset from NumPy arrays."""
 | 
				
			||||||
    def __init__(self, *arrays):
 | 
					    def __init__(self, *arrays):
 | 
				
			||||||
        tensors = [torch.Tensor(arr) for arr in arrays]
 | 
					        tensors = [torch.Tensor(arr) for arr in arrays]
 | 
				
			||||||
        super().__init__(*tensors)
 | 
					        super().__init__(*tensors)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user