From a22c752342af4deb4c71fdb46cf5293c7a985dbd Mon Sep 17 00:00:00 2001
From: blackfly
Date: Tue, 14 Apr 2020 19:47:34 +0200
Subject: [PATCH] Add prototorch/datasets

---
 prototorch/datasets/__init__.py |   0
 prototorch/datasets/abstract.py | 115 ++++++++++++++++++++++++++++++++
 prototorch/datasets/tecator.py  | 103 ++++++++++++++++++++++++++++
 3 files changed, 218 insertions(+)
 create mode 100644 prototorch/datasets/__init__.py
 create mode 100644 prototorch/datasets/abstract.py
 create mode 100644 prototorch/datasets/tecator.py

diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py
new file mode 100644
index 0000000..7817bb8
--- /dev/null
+++ b/prototorch/datasets/abstract.py
@@ -0,0 +1,115 @@
+"""ProtoTorch abstract datasets
+
+Based on `torchvision.VisionDataset` and `torchvision.MNIST`
+
+For the original code, see:
+https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
+https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
+"""
+
+import os
+
+import torch
+
+
+class Dataset(torch.utils.data.Dataset):
+    """Abstract dataset that only records where the data lives on disk.
+
+    Subclasses have to provide `__getitem__` and `__len__`.
+
+    Arguments:
+        root: Base directory of the dataset. A leading `~` in a `str` or
+            `bytes` path is expanded; any other value is stored unchanged.
+    """
+    _repr_indent = 2
+
+    def __init__(self, root):
+        # `torch._six` is a private module that has been removed from recent
+        # PyTorch releases. Its `string_classes` was `(str, bytes)`, so the
+        # builtin types are checked directly.
+        if isinstance(root, (str, bytes)):
+            root = os.path.expanduser(root)
+        self.root = root
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+
+class ProtoDataset(Dataset):
+    """Abstract dataset loaded from processed `.pt` split files.
+
+    Both splits live in `<root>/<ClassName>/processed`. Subclasses have to
+    implement `download` and `__getitem__`, and have to define `classes` if
+    `class_to_idx` is used.
+
+    Arguments:
+        root: Base directory, handled as in `Dataset`.
+        train: Load `training_file` if truthy, `test_file` otherwise.
+        download: Call `self.download()` before looking for the files.
+        verbose: Stored as `self.verbose` for subclasses to consult.
+
+    Raises:
+        RuntimeError: If a processed split file is missing.
+    """
+    training_file = 'training.pt'
+    test_file = 'test.pt'
+
+    def __init__(self, root, train=True, download=True, verbose=True):
+        super().__init__(root)
+        self.train = train  # training set or test set
+        self.verbose = verbose
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError('Dataset not found. '
+                               'You can use download=True to download it')
+
+        data_file = self.training_file if self.train else self.test_file
+        self.data, self.targets = torch.load(
+            os.path.join(self.processed_folder, data_file))
+
+    @property
+    def raw_folder(self):
+        """Directory for the raw files of this dataset class."""
+        return os.path.join(self.root, self.__class__.__name__, 'raw')
+
+    @property
+    def processed_folder(self):
+        """Directory for the processed split files of this dataset class."""
+        return os.path.join(self.root, self.__class__.__name__, 'processed')
+
+    @property
+    def class_to_idx(self):
+        """Map every entry of `self.classes` to its position in the list."""
+        return {_class: i for i, _class in enumerate(self.classes)}
+
+    def _check_exists(self):
+        """Return True only if both processed split files are present."""
+        return all(
+            os.path.exists(os.path.join(self.processed_folder, split_file))
+            for split_file in (self.training_file, self.test_file))
+
+    def __repr__(self):
+        head = 'Dataset ' + self.__class__.__name__
+        body = ['Number of datapoints: {}'.format(self.__len__())]
+        if self.root is not None:
+            body.append('Root location: {}'.format(self.root))
+        body += self.extra_repr().splitlines()
+        lines = [head] + [' ' * self._repr_indent + line for line in body]
+        return '\n'.join(lines)
+
+    def extra_repr(self):
+        """Return extra lines for `__repr__`; here the selected split."""
+        return f"Split: {'Train' if self.train is True else 'Test'}"
+
+    def __len__(self):
+        return len(self.data)
+
+    def download(self):
+        """Fetch and preprocess the data; subclasses have to implement it."""
+        raise NotImplementedError
diff --git a/prototorch/datasets/tecator.py b/prototorch/datasets/tecator.py
new file mode 100644
index 0000000..a65eaaf
--- /dev/null
+++ b/prototorch/datasets/tecator.py
@@ -0,0 +1,103 @@
+"""Tecator dataset for classification
+
+URL:
+    http://lib.stat.cmu.edu/datasets/tecator
+
+LICENCE / TERMS / COPYRIGHT:
+    This is the Tecator data set: The task is to predict the fat content
+    of a meat sample on the basis of its near infrared absorbance spectrum.
+    -------------------------------------------------------------------------
+    1. Statement of permission from Tecator (the original data source)
+
+    These data are recorded on a Tecator Infratec Food and Feed Analyzer
+    working in the wavelength range 850 - 1050 nm by the Near Infrared
+    Transmission (NIT) principle. Each sample contains finely chopped pure
+    meat with different moisture, fat and protein contents.
+
+    If results from these data are used in a publication we want you to
+    mention the instrument and company name (Tecator) in the publication.
+    In addition, please send a preprint of your article to
+
+        Karin Thente, Tecator AB,
+        Box 70, S-263 21 Hoganas, Sweden
+
+    The data are available in the public domain with no responsability from
+    the original data source. The data can be redistributed as long as this
+    permission note is attached.
+
+    For more information about the instrument - call Perstorp Analytical's
+    representative in your area.
+
+Description:
+    For each meat sample the data consists of a 100 channel spectrum of
+    absorbances and the contents of moisture (water), fat and protein.
+    The absorbance is -log10 of the transmittance
+    measured by the spectrometer. The three contents, measured in percent,
+    are determined by analytic chemistry.
+"""
+
+import os
+
+import numpy as np
+import torch
+from torchvision.datasets.utils import download_file_from_google_drive
+
+from prototorch.datasets.abstract import ProtoDataset
+
+
+class Tecator(ProtoDataset):
+    """Tecator dataset for classification
+
+    The raw data is an `.npz` archive downloaded from Google Drive that
+    holds the arrays `x_train`, `y_train`, `x_test` and `y_test`.
+    """
+    # (Google Drive file id, expected MD5 checksum) of every raw file.
+    resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
+                  'ba5607c580d0f91bb27dc29d13c2f8df')]
+    classes = ['0 - low_fat', '1 - high_fat']
+
+    def __getitem__(self, index):
+        """Return the sample at `index` and its target as a Python int."""
+        img, target = self.data[index], int(self.targets[index])
+        return img, target
+
+    def download(self):
+        """Download and preprocess the data if it does not exist already."""
+        if self._check_exists():
+            return
+
+        if self.verbose:
+            print('Making directories...')
+        os.makedirs(self.raw_folder, exist_ok=True)
+        os.makedirs(self.processed_folder, exist_ok=True)
+
+        if self.verbose:
+            print('Downloading...')
+        # Every resource is stored under the same name, which is also the
+        # file that gets processed below.
+        filename = 'tecator.npz'
+        for fileid, md5 in self.resources:
+            download_file_from_google_drive(fileid,
+                                            root=self.raw_folder,
+                                            filename=filename,
+                                            md5=md5)
+
+        if self.verbose:
+            print('Processing...')
+        with np.load(os.path.join(self.raw_folder, filename),
+                     allow_pickle=False) as npz:
+            x_train, y_train = npz['x_train'], npz['y_train']
+            x_test, y_test = npz['x_test'], npz['y_test']
+        training_set = [torch.as_tensor(x_train), torch.as_tensor(y_train)]
+        test_set = [torch.as_tensor(x_test), torch.as_tensor(y_test)]
+
+        # Written in the [data, targets] layout that `ProtoDataset` loads.
+        with open(os.path.join(self.processed_folder, self.training_file),
+                  'wb') as f:
+            torch.save(training_set, f)
+        with open(os.path.join(self.processed_folder, self.test_file),
+                  'wb') as f:
+            torch.save(test_set, f)
+
+        if self.verbose:
+            print('Done!')