[FEATURE] Change NumpyDataset.data to torch.Tensor
This commit is contained in:
parent
2722d976f5
commit
4ca581909a
@ -15,9 +15,9 @@ import torch
|
||||
class NumpyDataset(torch.utils.data.TensorDataset):
|
||||
"""Create a PyTorch TensorDataset from NumPy arrays."""
|
||||
def __init__(self, data, targets):
|
||||
self.data = data
|
||||
self.targets = targets
|
||||
tensors = [torch.Tensor(data), torch.Tensor(targets)]
|
||||
self.data = torch.Tensor(data)
|
||||
self.targets = torch.LongTensor(targets)
|
||||
tensors = [self.data, self.targets]
|
||||
super().__init__(*tensors)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user