Refactor datasets and use float32 instead of float64 in Tecator
This commit is contained in:
parent
a9d2855323
commit
553b1e1a65
@ -0,0 +1,7 @@
|
|||||||
|
"""ProtoTorch datasets."""
|
||||||
|
|
||||||
|
from .tecator import Tecator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Tecator',
|
||||||
|
]
|
@ -1,4 +1,4 @@
|
|||||||
"""Tecator dataset for classification
|
"""Tecator dataset for classification.
|
||||||
|
|
||||||
URL:
|
URL:
|
||||||
http://lib.stat.cmu.edu/datasets/tecator
|
http://lib.stat.cmu.edu/datasets/tecator
|
||||||
@ -46,9 +46,11 @@ from prototorch.datasets.abstract import ProtoDataset
|
|||||||
|
|
||||||
|
|
||||||
class Tecator(ProtoDataset):
|
class Tecator(ProtoDataset):
|
||||||
"""Tecator dataset for classification"""
|
"""Tecator dataset for classification."""
|
||||||
resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
|
resources = [
|
||||||
'ba5607c580d0f91bb27dc29d13c2f8df')]
|
('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
|
||||||
|
'ba5607c580d0f91bb27dc29d13c2f8df'),
|
||||||
|
] # (google_storage_id, md5hash)
|
||||||
classes = ['0 - low_fat', '1 - high_fat']
|
classes = ['0 - low_fat', '1 - high_fat']
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
@ -80,8 +82,14 @@ class Tecator(ProtoDataset):
|
|||||||
allow_pickle=False) as f:
|
allow_pickle=False) as f:
|
||||||
x_train, y_train = f['x_train'], f['y_train']
|
x_train, y_train = f['x_train'], f['y_train']
|
||||||
x_test, y_test = f['x_test'], f['y_test']
|
x_test, y_test = f['x_test'], f['y_test']
|
||||||
training_set = [torch.as_tensor(x_train), torch.as_tensor(y_train)]
|
training_set = [
|
||||||
test_set = [torch.as_tensor(x_test), torch.as_tensor(y_test)]
|
torch.tensor(x_train, dtype=torch.float32),
|
||||||
|
torch.tensor(y_train),
|
||||||
|
]
|
||||||
|
test_set = [
|
||||||
|
torch.tensor(x_test, dtype=torch.float32),
|
||||||
|
torch.tensor(y_test),
|
||||||
|
]
|
||||||
|
|
||||||
with open(os.path.join(self.processed_folder, self.training_file),
|
with open(os.path.join(self.processed_folder, self.training_file),
|
||||||
'wb') as f:
|
'wb') as f:
|
||||||
|
Loading…
Reference in New Issue
Block a user