Add NumpyDataset.

This commit is contained in:
Alexander Engelsberger 2021-04-23 17:22:15 +02:00
parent 4540c8848e
commit e1d56595c1

View File

@ -12,6 +12,12 @@ import os
import torch import torch
class NumpyDataset(torch.utils.data.TensorDataset):
def __init__(self, *arrays):
tensors = [torch.Tensor(arr) for arr in arrays]
super().__init__(*tensors)
class Dataset(torch.utils.data.Dataset): class Dataset(torch.utils.data.Dataset):
"""Abstract dataset class to be inherited.""" """Abstract dataset class to be inherited."""
@ -44,15 +50,13 @@ class ProtoDataset(Dataset):
self._download() self._download()
if not self._check_exists(): if not self._check_exists():
raise RuntimeError( raise RuntimeError("Dataset not found. "
"Dataset not found. " "You can use download=True to download it" "You can use download=True to download it")
)
data_file = self.training_file if self.train else self.test_file data_file = self.training_file if self.train else self.test_file
self.data, self.targets = torch.load( self.data, self.targets = torch.load(
os.path.join(self.processed_folder, data_file) os.path.join(self.processed_folder, data_file))
)
@property @property
def raw_folder(self): def raw_folder(self):
@ -68,8 +72,9 @@ class ProtoDataset(Dataset):
def _check_exists(self): def _check_exists(self):
return os.path.exists( return os.path.exists(
os.path.join(self.processed_folder, self.training_file) os.path.join(
) and os.path.exists(os.path.join(self.processed_folder, self.test_file)) self.processed_folder, self.training_file)) and os.path.exists(
os.path.join(self.processed_folder, self.test_file))
def __repr__(self): def __repr__(self):
head = "Dataset " + self.__class__.__name__ head = "Dataset " + self.__class__.__name__