[BUGFIX] Initializers can handle Dataloaders now
This commit is contained in:
parent
14508f0600
commit
ee30d4da5b
@ -7,17 +7,18 @@ import torch
|
|||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
|
||||||
def parse_data_arg(data):
|
def parse_data_arg(data_arg):
|
||||||
if isinstance(data, Dataset):
|
if isinstance(data_arg, Dataset):
|
||||||
data, labels = next(iter(DataLoader(data, batch_size=len(data))))
|
data_arg = DataLoader(data_arg, batch_size=len(data_arg))
|
||||||
elif isinstance(data, DataLoader):
|
|
||||||
|
if isinstance(data_arg, DataLoader):
|
||||||
data = torch.tensor([])
|
data = torch.tensor([])
|
||||||
labels = torch.tensor([])
|
labels = torch.tensor([])
|
||||||
for x, y in data:
|
for x, y in data_arg:
|
||||||
data = torch.cat([data, x])
|
data = torch.cat([data, x])
|
||||||
labels = torch.cat([labels, y])
|
labels = torch.cat([labels, y])
|
||||||
else:
|
else:
|
||||||
data, labels = data
|
data, labels = data_arg
|
||||||
if not isinstance(data, torch.Tensor):
|
if not isinstance(data, torch.Tensor):
|
||||||
wmsg = f"Converting data to {torch.Tensor}."
|
wmsg = f"Converting data to {torch.Tensor}."
|
||||||
warnings.warn(wmsg)
|
warnings.warn(wmsg)
|
||||||
|
Loading…
Reference in New Issue
Block a user