Refactor datasets and use float32 instead of float64 in Tecator
This commit is contained in:
parent
a9d2855323
commit
553b1e1a65
@ -0,0 +1,7 @@
|
||||
"""ProtoTorch datasets."""
|
||||
|
||||
from .tecator import Tecator
|
||||
|
||||
__all__ = [
|
||||
'Tecator',
|
||||
]
|
@ -1,4 +1,4 @@
|
||||
"""Tecator dataset for classification
|
||||
"""Tecator dataset for classification.
|
||||
|
||||
URL:
|
||||
http://lib.stat.cmu.edu/datasets/tecator
|
||||
@ -46,9 +46,11 @@ from prototorch.datasets.abstract import ProtoDataset
|
||||
|
||||
|
||||
class Tecator(ProtoDataset):
|
||||
"""Tecator dataset for classification"""
|
||||
resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
|
||||
'ba5607c580d0f91bb27dc29d13c2f8df')]
|
||||
"""Tecator dataset for classification."""
|
||||
resources = [
|
||||
('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
|
||||
'ba5607c580d0f91bb27dc29d13c2f8df'),
|
||||
] # (google_storage_id, md5hash)
|
||||
classes = ['0 - low_fat', '1 - high_fat']
|
||||
|
||||
def __getitem__(self, index):
|
||||
@ -80,8 +82,14 @@ class Tecator(ProtoDataset):
|
||||
allow_pickle=False) as f:
|
||||
x_train, y_train = f['x_train'], f['y_train']
|
||||
x_test, y_test = f['x_test'], f['y_test']
|
||||
training_set = [torch.as_tensor(x_train), torch.as_tensor(y_train)]
|
||||
test_set = [torch.as_tensor(x_test), torch.as_tensor(y_test)]
|
||||
training_set = [
|
||||
torch.tensor(x_train, dtype=torch.float32),
|
||||
torch.tensor(y_train),
|
||||
]
|
||||
test_set = [
|
||||
torch.tensor(x_test, dtype=torch.float32),
|
||||
torch.tensor(y_test),
|
||||
]
|
||||
|
||||
with open(os.path.join(self.processed_folder, self.training_file),
|
||||
'wb') as f:
|
||||
|
Loading…
Reference in New Issue
Block a user