feat: add CSVDataset
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""ProtoTorch datasets"""
|
||||
|
||||
from .abstract import NumpyDataset
|
||||
from .abstract import CSVDataset, NumpyDataset
|
||||
from .sklearn import (
|
||||
Blobs,
|
||||
Circles,
|
||||
|
@@ -10,6 +10,7 @@ https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -97,3 +98,16 @@ class NumpyDataset(torch.utils.data.TensorDataset):
|
||||
self.targets = torch.LongTensor(targets)
|
||||
tensors = [self.data, self.targets]
|
||||
super().__init__(*tensors)
|
||||
|
||||
|
||||
class CSVDataset(NumpyDataset):
|
||||
"""Create a Dataset from a CSV file."""
|
||||
def __init__(self, filepath, target_col=-1, delimiter=',', skip_header=0):
|
||||
raw = np.genfromtxt(
|
||||
filepath,
|
||||
delimiter=delimiter,
|
||||
skip_header=skip_header,
|
||||
)
|
||||
data = np.delete(raw, 1, target_col)
|
||||
targets = raw[:, target_col]
|
||||
super().__init__(data, targets)
|
||||
|
Reference in New Issue
Block a user