Refactor datasets and use float32 instead of float64 in Tecator

This commit is contained in:
blackfly 2020-04-14 19:49:59 +02:00
parent a9d2855323
commit 553b1e1a65
2 changed files with 21 additions and 6 deletions

View File

@@ -0,0 +1,7 @@
"""ProtoTorch datasets."""
from .tecator import Tecator
__all__ = [
'Tecator',
]

View File

@@ -1,4 +1,4 @@
"""Tecator dataset for classification """Tecator dataset for classification.
URL: URL:
http://lib.stat.cmu.edu/datasets/tecator http://lib.stat.cmu.edu/datasets/tecator
@@ -46,9 +46,11 @@ from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset): class Tecator(ProtoDataset):
"""Tecator dataset for classification""" """Tecator dataset for classification."""
resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0', resources = [
'ba5607c580d0f91bb27dc29d13c2f8df')] ('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
'ba5607c580d0f91bb27dc29d13c2f8df'),
] # (google_storage_id, md5hash)
classes = ['0 - low_fat', '1 - high_fat'] classes = ['0 - low_fat', '1 - high_fat']
def __getitem__(self, index): def __getitem__(self, index):
@@ -80,8 +82,14 @@ class Tecator(ProtoDataset):
allow_pickle=False) as f: allow_pickle=False) as f:
x_train, y_train = f['x_train'], f['y_train'] x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test'] x_test, y_test = f['x_test'], f['y_test']
training_set = [torch.as_tensor(x_train), torch.as_tensor(y_train)] training_set = [
test_set = [torch.as_tensor(x_test), torch.as_tensor(y_test)] torch.tensor(x_train, dtype=torch.float32),
torch.tensor(y_train),
]
test_set = [
torch.tensor(x_test, dtype=torch.float32),
torch.tensor(y_test),
]
with open(os.path.join(self.processed_folder, self.training_file), with open(os.path.join(self.processed_folder, self.training_file),
'wb') as f: 'wb') as f: