From e1d56595c147fde8c7587d39014d8b0943654d1c Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Fri, 23 Apr 2021 17:22:15 +0200 Subject: [PATCH] Add NumpyDataset. --- prototorch/datasets/abstract.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index 58dccee..7ff92aa 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -12,6 +12,12 @@ import os import torch +class NumpyDataset(torch.utils.data.TensorDataset): + def __init__(self, *arrays): + tensors = [torch.Tensor(arr) for arr in arrays] + super().__init__(*tensors) + + class Dataset(torch.utils.data.Dataset): """Abstract dataset class to be inherited.""" @@ -44,15 +50,13 @@ class ProtoDataset(Dataset): self._download() if not self._check_exists(): - raise RuntimeError( - "Dataset not found. " "You can use download=True to download it" - ) + raise RuntimeError("Dataset not found. " + "You can use download=True to download it") data_file = self.training_file if self.train else self.test_file self.data, self.targets = torch.load( - os.path.join(self.processed_folder, data_file) - ) + os.path.join(self.processed_folder, data_file)) @property def raw_folder(self): @@ -68,8 +72,9 @@ class ProtoDataset(Dataset): def _check_exists(self): return os.path.exists( - os.path.join(self.processed_folder, self.training_file) - ) and os.path.exists(os.path.join(self.processed_folder, self.test_file)) + os.path.join( + self.processed_folder, self.training_file)) and os.path.exists( + os.path.join(self.processed_folder, self.test_file)) def __repr__(self): head = "Dataset " + self.__class__.__name__