Add NumpyDataset.
This commit is contained in:
		| @@ -12,6 +12,12 @@ import os | |||||||
| import torch | import torch | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class NumpyDataset(torch.utils.data.TensorDataset): | ||||||
|  |     def __init__(self, *arrays): | ||||||
|  |         tensors = [torch.Tensor(arr) for arr in arrays] | ||||||
|  |         super().__init__(*tensors) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Dataset(torch.utils.data.Dataset): | class Dataset(torch.utils.data.Dataset): | ||||||
|     """Abstract dataset class to be inherited.""" |     """Abstract dataset class to be inherited.""" | ||||||
|  |  | ||||||
| @@ -44,15 +50,13 @@ class ProtoDataset(Dataset): | |||||||
|             self._download() |             self._download() | ||||||
|  |  | ||||||
|         if not self._check_exists(): |         if not self._check_exists(): | ||||||
|             raise RuntimeError( |             raise RuntimeError("Dataset not found. " | ||||||
|                 "Dataset not found. " "You can use download=True to download it" |                                "You can use download=True to download it") | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         data_file = self.training_file if self.train else self.test_file |         data_file = self.training_file if self.train else self.test_file | ||||||
|  |  | ||||||
|         self.data, self.targets = torch.load( |         self.data, self.targets = torch.load( | ||||||
|             os.path.join(self.processed_folder, data_file) |             os.path.join(self.processed_folder, data_file)) | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def raw_folder(self): |     def raw_folder(self): | ||||||
| @@ -68,8 +72,9 @@ class ProtoDataset(Dataset): | |||||||
|  |  | ||||||
|     def _check_exists(self): |     def _check_exists(self): | ||||||
|         return os.path.exists( |         return os.path.exists( | ||||||
|             os.path.join(self.processed_folder, self.training_file) |             os.path.join( | ||||||
|         ) and os.path.exists(os.path.join(self.processed_folder, self.test_file)) |                 self.processed_folder, self.training_file)) and os.path.exists( | ||||||
|  |                     os.path.join(self.processed_folder, self.test_file)) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         head = "Dataset " + self.__class__.__name__ |         head = "Dataset " + self.__class__.__name__ | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user