From 4ca581909a86f426c768206a495ff2b17a4e2c1f Mon Sep 17 00:00:00 2001 From: Alexander Engelsberger Date: Tue, 1 Jun 2021 17:17:42 +0200 Subject: [PATCH] [FEATURE] Change NumpyDataset.data to torch.Tensor --- prototorch/datasets/abstract.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prototorch/datasets/abstract.py b/prototorch/datasets/abstract.py index e6d7b7b..e941c95 100644 --- a/prototorch/datasets/abstract.py +++ b/prototorch/datasets/abstract.py @@ -15,9 +15,9 @@ import torch class NumpyDataset(torch.utils.data.TensorDataset): """Create a PyTorch TensorDataset from NumPy arrays.""" def __init__(self, data, targets): - self.data = data - self.targets = targets - tensors = [torch.Tensor(data), torch.Tensor(targets)] + self.data = torch.Tensor(data) + self.targets = torch.LongTensor(targets) + tensors = [self.data, self.targets] super().__init__(*tensors)