Refactor datasets and use float32 instead of float64 in Tecator

This commit is contained in:
blackfly 2020-04-14 19:49:59 +02:00
parent a9d2855323
commit 553b1e1a65
2 changed files with 21 additions and 6 deletions

View File

@@ -0,0 +1,7 @@
"""ProtoTorch datasets."""
# Re-export the Tecator dataset class from the sibling ``tecator`` module.
from .tecator import Tecator
# Names exported when this package is star-imported.
__all__ = [
'Tecator',
]

View File

@@ -1,4 +1,4 @@
"""Tecator dataset for classification
"""Tecator dataset for classification.
URL:
http://lib.stat.cmu.edu/datasets/tecator
@@ -46,9 +46,11 @@ from prototorch.datasets.abstract import ProtoDataset
class Tecator(ProtoDataset):
"""Tecator dataset for classification"""
resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
'ba5607c580d0f91bb27dc29d13c2f8df')]
"""Tecator dataset for classification."""
resources = [
('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
'ba5607c580d0f91bb27dc29d13c2f8df'),
] # (google_storage_id, md5hash)
classes = ['0 - low_fat', '1 - high_fat']
def __getitem__(self, index):
@@ -80,8 +82,14 @@ class Tecator(ProtoDataset):
allow_pickle=False) as f:
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
training_set = [torch.as_tensor(x_train), torch.as_tensor(y_train)]
test_set = [torch.as_tensor(x_test), torch.as_tensor(y_test)]
training_set = [
torch.tensor(x_train, dtype=torch.float32),
torch.tensor(y_train),
]
test_set = [
torch.tensor(x_test, dtype=torch.float32),
torch.tensor(y_test),
]
with open(os.path.join(self.processed_folder, self.training_file),
'wb') as f: