diff --git a/prototorch/datasets/__init__.py b/prototorch/datasets/__init__.py
index e69de29..3607893 100644
--- a/prototorch/datasets/__init__.py
+++ b/prototorch/datasets/__init__.py
@@ -0,0 +1,7 @@
+"""ProtoTorch datasets."""
+
+from .tecator import Tecator
+
+__all__ = [
+    'Tecator',
+]
diff --git a/prototorch/datasets/tecator.py b/prototorch/datasets/tecator.py
index a65eaaf..b86f059 100644
--- a/prototorch/datasets/tecator.py
+++ b/prototorch/datasets/tecator.py
@@ -1,4 +1,4 @@
-"""Tecator dataset for classification
+"""Tecator dataset for classification.
 
 URL:
     http://lib.stat.cmu.edu/datasets/tecator
@@ -46,9 +46,11 @@ from prototorch.datasets.abstract import ProtoDataset
 
 
 class Tecator(ProtoDataset):
-    """Tecator dataset for classification"""
-    resources = [('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
-                  'ba5607c580d0f91bb27dc29d13c2f8df')]
+    """Tecator dataset for classification."""
+    resources = [
+        ('1MMuUK8V41IgNpnPDbg3E-QAL6wlErTk0',
+         'ba5607c580d0f91bb27dc29d13c2f8df'),
+    ]  # (google_storage_id, md5hash)
     classes = ['0 - low_fat', '1 - high_fat']
 
     def __getitem__(self, index):
@@ -80,8 +82,14 @@ class Tecator(ProtoDataset):
                      allow_pickle=False) as f:
             x_train, y_train = f['x_train'], f['y_train']
             x_test, y_test = f['x_test'], f['y_test']
-            training_set = [torch.as_tensor(x_train), torch.as_tensor(y_train)]
-            test_set = [torch.as_tensor(x_test), torch.as_tensor(y_test)]
+            training_set = [
+                torch.tensor(x_train, dtype=torch.float32),
+                torch.tensor(y_train),
+            ]
+            test_set = [
+                torch.tensor(x_test, dtype=torch.float32),
+                torch.tensor(y_test),
+            ]
 
         with open(os.path.join(self.processed_folder, self.training_file),
                   'wb') as f: