Tecator.data is a Tensor and Tecator.targets is a LongTensor

This commit is contained in:
Jensun Ravichandran 2021-06-01 17:28:37 +02:00
parent 4ca581909a
commit ff69eb1256

View File

@ -101,12 +101,12 @@ class Tecator(ProtoDataset):
x_train, y_train = f["x_train"], f["y_train"]
x_test, y_test = f["x_test"], f["y_test"]
training_set = [
torch.tensor(x_train, dtype=torch.float32),
torch.tensor(y_train),
torch.Tensor(x_train),
torch.LongTensor(y_train),
]
test_set = [
torch.tensor(x_test, dtype=torch.float32),
torch.tensor(y_test),
torch.Tensor(x_test),
torch.LongTensor(y_test),
]
with open(os.path.join(self.processed_folder, self.training_file),