Add prototorch/datasets
This commit is contained in:
parent
4158586cb9
commit
a22c752342
0
prototorch/datasets/__init__.py
Normal file
0
prototorch/datasets/__init__.py
Normal file
87
prototorch/datasets/abstract.py
Normal file
87
prototorch/datasets/abstract.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
"""ProtoTorch abstract datasets
|
||||||
|
|
||||||
|
Based on `torchvision.VisionDataset` and `torchvision.MNIST`
|
||||||
|
|
||||||
|
For the original code, see:
|
||||||
|
https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py
|
||||||
|
https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class Dataset(torch.utils.data.Dataset):
|
||||||
|
"""Abstract dataset class to be inherited"""
|
||||||
|
_repr_indent = 2
|
||||||
|
|
||||||
|
def __init__(self, root):
|
||||||
|
if isinstance(root, torch._six.string_classes):
|
||||||
|
root = os.path.expanduser(root)
|
||||||
|
self.root = root
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ProtoDataset(Dataset):
|
||||||
|
"""Abstract dataset class to be inherited"""
|
||||||
|
training_file = 'training.pt'
|
||||||
|
test_file = 'test.pt'
|
||||||
|
|
||||||
|
def __init__(self, root, train=True, download=True, verbose=True):
|
||||||
|
super().__init__(root)
|
||||||
|
self.train = train # training set or test set
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
|
if download:
|
||||||
|
self.download()
|
||||||
|
|
||||||
|
if not self._check_exists():
|
||||||
|
raise RuntimeError('Dataset not found. '
|
||||||
|
'You can use download=True to download it')
|
||||||
|
|
||||||
|
data_file = self.training_file if self.train else self.test_file
|
||||||
|
|
||||||
|
self.data, self.targets = torch.load(
|
||||||
|
os.path.join(self.processed_folder, data_file))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def raw_folder(self):
|
||||||
|
return os.path.join(self.root, self.__class__.__name__, 'raw')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def processed_folder(self):
|
||||||
|
return os.path.join(self.root, self.__class__.__name__, 'processed')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def class_to_idx(self):
|
||||||
|
return {_class: i for i, _class in enumerate(self.classes)}
|
||||||
|
|
||||||
|
def _check_exists(self):
|
||||||
|
return (os.path.exists(
|
||||||
|
os.path.join(self.processed_folder, self.training_file))
|
||||||
|
and os.path.exists(
|
||||||
|
os.path.join(self.processed_folder, self.test_file)))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
head = 'Dataset ' + self.__class__.__name__
|
||||||
|
body = ['Number of datapoints: {}'.format(self.__len__())]
|
||||||
|
if self.root is not None:
|
||||||
|
body.append('Root location: {}'.format(self.root))
|
||||||
|
body += self.extra_repr().splitlines()
|
||||||
|
lines = [head] + [' ' * self._repr_indent + line for line in body]
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"Split: {'Train' if self.train is True else 'Test'}"
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def download(self):
|
||||||
|
raise NotImplementedError
|
94
prototorch/datasets/tecator.py
Normal file
94
prototorch/datasets/tecator.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
"""Tecator dataset for classification
|
||||||
|
|
||||||
|
URL:
|
||||||
|
http://lib.stat.cmu.edu/datasets/tecator
|
||||||
|
|
||||||
|
LICENCE / TERMS / COPYRIGHT:
|
||||||
|
This is the Tecator data set: The task is to predict the fat content
|
||||||
|
of a meat sample on the basis of its near infrared absorbance spectrum.
|
||||||
|
-------------------------------------------------------------------------
|
||||||
|
1. Statement of permission from Tecator (the original data source)
|
||||||
|
|
||||||
|
These data are recorded on a Tecator Infratec Food and Feed Analyzer
|
||||||
|
working in the wavelength range 850 - 1050 nm by the Near Infrared
|
||||||
|
Transmission (NIT) principle. Each sample contains finely chopped pure
|
||||||
|
meat with different moisture, fat and protein contents.
|
||||||
|
|
||||||
|
If results from these data are used in a publication we want you to
|
||||||
|
mention the instrument and company name (Tecator) in the publication.
|
||||||
|
In addition, please send a preprint of your article to
|
||||||
|
|
||||||
|
Karin Thente, Tecator AB,
|
||||||
|
Box 70, S-263 21 Hoganas, Sweden
|
||||||
|
|
||||||
|
The data are available in the public domain with no responsability from
|
||||||
|
the original data source. The data can be redistributed as long as this
|
||||||
|
permission note is attached.
|
||||||
|
|
||||||
|
For more information about the instrument - call Perstorp Analytical's
|
||||||
|
representative in your area.
|
||||||
|
|
||||||
|
Description:
|
||||||
|
For each meat sample the data consists of a 100 channel spectrum of
|
||||||
|
absorbances and the contents of moisture (water), fat and protein.
|
||||||
|
The absorbance is -log10 of the transmittance
|
||||||
|
measured by the spectrometer. The three contents, measured in percent,
|
||||||
|
are determined by analytic chemistry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision.datasets.utils import download_file_from_google_drive
|
||||||
|
|
||||||
|
from prototorch.datasets.abstract import ProtoDataset
|
||||||
|
|
||||||
|
|
||||||
|
class Tecator(ProtoDataset):
|
||||||
|
"""Tecator dataset for classification"""
|
||||||
|
resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
|
||||||
|
'ba5607c580d0f91bb27dc29d13c2f8df')]
|
||||||
|
classes = ['0 - low_fat', '1 - high_fat']
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img, target = self.data[index], int(self.targets[index])
|
||||||
|
return img, target
|
||||||
|
|
||||||
|
def download(self):
|
||||||
|
"""Download the data if it doesn't exist in already."""
|
||||||
|
if self._check_exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print('Making directories...')
|
||||||
|
os.makedirs(self.raw_folder, exist_ok=True)
|
||||||
|
os.makedirs(self.processed_folder, exist_ok=True)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print('Downloading...')
|
||||||
|
for fileid, md5 in self.resources:
|
||||||
|
filename = 'tecator.npz'
|
||||||
|
download_file_from_google_drive(fileid,
|
||||||
|
root=self.raw_folder,
|
||||||
|
filename=filename,
|
||||||
|
md5=md5)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print('Processing...')
|
||||||
|
with np.load(os.path.join(self.raw_folder, 'tecator.npz'),
|
||||||
|
allow_pickle=False) as f:
|
||||||
|
x_train, y_train = f['x_train'], f['y_train']
|
||||||
|
x_test, y_test = f['x_test'], f['y_test']
|
||||||
|
training_set = [torch.as_tensor(x_train), torch.as_tensor(y_train)]
|
||||||
|
test_set = [torch.as_tensor(x_test), torch.as_tensor(y_test)]
|
||||||
|
|
||||||
|
with open(os.path.join(self.processed_folder, self.training_file),
|
||||||
|
'wb') as f:
|
||||||
|
torch.save(training_set, f)
|
||||||
|
with open(os.path.join(self.processed_folder, self.test_file),
|
||||||
|
'wb') as f:
|
||||||
|
torch.save(test_set, f)
|
||||||
|
|
||||||
|
if self.verbose:
|
||||||
|
print('Done!')
|
Loading…
Reference in New Issue
Block a user