Add NumpyDataset.
This commit is contained in:
parent
4540c8848e
commit
e1d56595c1
@ -12,6 +12,12 @@ import os
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||||
|
def __init__(self, *arrays):
|
||||||
|
tensors = [torch.Tensor(arr) for arr in arrays]
|
||||||
|
super().__init__(*tensors)
|
||||||
|
|
||||||
|
|
||||||
class Dataset(torch.utils.data.Dataset):
|
class Dataset(torch.utils.data.Dataset):
|
||||||
"""Abstract dataset class to be inherited."""
|
"""Abstract dataset class to be inherited."""
|
||||||
|
|
||||||
@ -44,15 +50,13 @@ class ProtoDataset(Dataset):
|
|||||||
self._download()
|
self._download()
|
||||||
|
|
||||||
if not self._check_exists():
|
if not self._check_exists():
|
||||||
raise RuntimeError(
|
raise RuntimeError("Dataset not found. "
|
||||||
"Dataset not found. " "You can use download=True to download it"
|
"You can use download=True to download it")
|
||||||
)
|
|
||||||
|
|
||||||
data_file = self.training_file if self.train else self.test_file
|
data_file = self.training_file if self.train else self.test_file
|
||||||
|
|
||||||
self.data, self.targets = torch.load(
|
self.data, self.targets = torch.load(
|
||||||
os.path.join(self.processed_folder, data_file)
|
os.path.join(self.processed_folder, data_file))
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def raw_folder(self):
|
def raw_folder(self):
|
||||||
@ -68,8 +72,9 @@ class ProtoDataset(Dataset):
|
|||||||
|
|
||||||
def _check_exists(self):
|
def _check_exists(self):
|
||||||
return os.path.exists(
|
return os.path.exists(
|
||||||
os.path.join(self.processed_folder, self.training_file)
|
os.path.join(
|
||||||
) and os.path.exists(os.path.join(self.processed_folder, self.test_file))
|
self.processed_folder, self.training_file)) and os.path.exists(
|
||||||
|
os.path.join(self.processed_folder, self.test_file))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
head = "Dataset " + self.__class__.__name__
|
head = "Dataset " + self.__class__.__name__
|
||||||
|
Loading…
Reference in New Issue
Block a user