feat: add CSVDataset

This commit is contained in:
Jensun Ravichandran 2021-07-04 16:30:01 +02:00
parent eb79b703d8
commit fdb9a7c66d
No known key found for this signature in database
GPG Key ID: 1BB4A641722D6B23
3 changed files with 30 additions and 1 deletions

View File

@ -1,6 +1,6 @@
"""ProtoTorch datasets""" """ProtoTorch datasets"""
from .abstract import NumpyDataset from .abstract import CSVDataset, NumpyDataset
from .sklearn import ( from .sklearn import (
Blobs, Blobs,
Circles, Circles,

View File

@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
import os import os
import numpy as np
import torch import torch
@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
self.targets = torch.LongTensor(targets) self.targets = torch.LongTensor(targets)
tensors = [self.data, self.targets] tensors = [self.data, self.targets]
super().__init__(*tensors) super().__init__(*tensors)
class CSVDataset(NumpyDataset):
"""Create a Dataset from a CSV file."""
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
raw = np.genfromtxt(
filepath,
delimiter=delimiter,
skip_header=skip_header,
)
data = np.delete(raw, 1, target_col)
targets = raw[:, target_col]
super().__init__(data, targets)

View File

@ -49,6 +49,21 @@ class TestNumpyDataset(unittest.TestCase):
self.assertEqual(len(ds), 3) self.assertEqual(len(ds), 3)
class TestCSVDataset(unittest.TestCase):
def setUp(self):
data = np.random.rand(100, 4)
targets = np.random.randint(2, size=(100, 1))
arr = np.hstack([data, targets])
np.savetxt("./artifacts/test.csv", arr, delimiter=",")
def test_len(self):
ds = pt.datasets.CSVDataset("./artifacts/test.csv")
self.assertEqual(len(ds), 100)
def tearDown(self):
os.remove("./artifacts/test.csv")
class TestSpiral(unittest.TestCase): class TestSpiral(unittest.TestCase):
def test_init(self): def test_init(self):
ds = pt.datasets.Spiral(num_samples=10) ds = pt.datasets.Spiral(num_samples=10)