Add prototorch/datasets

This commit is contained in:
blackfly 2020-04-14 19:47:34 +02:00
parent 4158586cb9
commit a22c752342
3 changed files with 181 additions and 0 deletions

View File

View File

@@ -0,0 +1,87 @@
"""ProtoTorch abstract datasets
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
For the original code, see:
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
"""
import os
import torch
class Dataset(torch.utils.data.Dataset):
    """Abstract dataset class to be inherited.

    Stores a (user-expanded) root directory. Subclasses must implement
    `__getitem__` and `__len__`.
    """
    # Indentation width used by `ProtoDataset.__repr__` when listing details.
    _repr_indent = 2

    def __init__(self, root):
        # `torch._six.string_classes` was removed in PyTorch 1.13; a plain
        # `str` check is the supported replacement for path-like strings.
        if isinstance(root, str):
            root = os.path.expanduser(root)
        self.root = root

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
class ProtoDataset(Dataset):
    """Abstract dataset class to be inherited."""

    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, download=True, verbose=True):
        """Load (optionally downloading first) the requested split.

        Args:
            root: base directory for raw and processed data.
            train: True selects the training split, False the test split.
            download: if True, call `self.download()` before loading.
            verbose: if True, subclasses print progress messages.
        """
        super().__init__(root)
        self.train = train  # True -> training split, False -> test split
        self.verbose = verbose

        if download:
            self.download()
        if not self._check_exists():
            raise RuntimeError('Dataset not found. '
                               'You can use download=True to download it')

        split_file = self.training_file if self.train else self.test_file
        split_path = os.path.join(self.processed_folder, split_file)
        self.data, self.targets = torch.load(split_path)

    @property
    def raw_folder(self):
        # Directory where the downloaded, unprocessed files live.
        return os.path.join(self.root, type(self).__name__, 'raw')

    @property
    def processed_folder(self):
        # Directory holding the serialized training/test tensors.
        return os.path.join(self.root, type(self).__name__, 'processed')

    @property
    def class_to_idx(self):
        # Map each class label (from the subclass's `classes`) to its index.
        return {label: index for index, label in enumerate(self.classes)}

    def _check_exists(self):
        # Both processed split files must be present.
        expected = (self.training_file, self.test_file)
        return all(
            os.path.exists(os.path.join(self.processed_folder, name))
            for name in expected)

    def __repr__(self):
        details = ['Number of datapoints: {}'.format(len(self))]
        if self.root is not None:
            details.append('Root location: {}'.format(self.root))
        details.extend(self.extra_repr().splitlines())
        indent = ' ' * self._repr_indent
        lines = ['Dataset ' + type(self).__name__]
        lines += [indent + entry for entry in details]
        return '\n'.join(lines)

    def extra_repr(self):
        split_name = 'Train' if self.train is True else 'Test'
        return f"Split: {split_name}"

    def __len__(self):
        return len(self.data)

    def download(self):
        raise NotImplementedError

View File

@@ -0,0 +1,94 @@
"""Tecator dataset for classification
URL:
http://lib.stat.cmu.edu/datasets/tecator
LICENCE / TERMS / COPYRIGHT:
This is the Tecator data set: The task is to predict the fat content
of a meat sample on the basis of its near infrared absorbance spectrum.
-------------------------------------------------------------------------
1. Statement of permission from Tecator (the original data source)
These data are recorded on a Tecator Infratec Food and Feed Analyzer
working in the wavelength range 850 - 1050 nm by the Near Infrared
Transmission (NIT) principle. Each sample contains finely chopped pure
meat with different moisture, fat and protein contents.
If results from these data are used in a publication we want you to
mention the instrument and company name (Tecator) in the publication.
In addition, please send a preprint of your article to
Karin Thente, Tecator AB,
Box 70, S-263 21 Hoganas, Sweden
The data are available in the public domain with no responsability from
the original data source. The data can be redistributed as long as this
permission note is attached.
For more information about the instrument - call Perstorp Analytical's
representative in your area.
Description:
For each meat sample the data consists of a 100 channel spectrum of
absorbances and the contents of moisture (water), fat and protein.
The absorbance is -log10 of the transmittance
measured by the spectrometer. The three contents, measured in percent,
are determined by analytic chemistry.
"""
import os
import numpy as np
import torch
from torchvision.datasets.utils import download_file_from_google_drive
from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset):
    """Tecator dataset for classification."""

    # (google-drive file id, md5 checksum) pairs to fetch.
    resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
                  'ba5607c580d0f91bb27dc29d13c2f8df')]
    classes = ['0 - low_fat', '1 - high_fat']

    def __getitem__(self, index):
        """Return the (spectrum, label) pair at `index`."""
        sample = self.data[index]
        label = int(self.targets[index])
        return sample, label

    def download(self):
        """Download and process the data if it isn't already present."""
        if self._check_exists():
            return

        if self.verbose:
            print('Making directories...')
        for folder in (self.raw_folder, self.processed_folder):
            os.makedirs(folder, exist_ok=True)

        if self.verbose:
            print('Downloading...')
        filename = 'tecator.npz'
        for fileid, md5 in self.resources:
            download_file_from_google_drive(fileid,
                                            root=self.raw_folder,
                                            filename=filename,
                                            md5=md5)

        if self.verbose:
            print('Processing...')
        # allow_pickle=False: the archive holds plain arrays; refusing
        # pickles avoids executing untrusted data.
        with np.load(os.path.join(self.raw_folder, 'tecator.npz'),
                     allow_pickle=False) as archive:
            x_train, y_train = archive['x_train'], archive['y_train']
            x_test, y_test = archive['x_test'], archive['y_test']
        train_tensors = [torch.as_tensor(x_train), torch.as_tensor(y_train)]
        test_tensors = [torch.as_tensor(x_test), torch.as_tensor(y_test)]

        splits = ((self.training_file, train_tensors),
                  (self.test_file, test_tensors))
        for fname, payload in splits:
            with open(os.path.join(self.processed_folder, fname), 'wb') as fp:
                torch.save(payload, fp)

        if self.verbose:
            print('Done!')