[FEATURE] Add more initializers
This commit is contained in:
@@ -67,17 +67,19 @@ def parse_distribution(user_distribution: Union[dict[int, int], dict[str, str],
|
||||
elif isinstance(user_distribution, list):
|
||||
return distribution_from_list(user_distribution, clabels)
|
||||
else:
|
||||
msg = f"`distribution` not understood." \
|
||||
msg = f"`distribution` was not understood." \
|
||||
f"You have provided: {user_distribution}."
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||
"""Return data and target as torch tensors."""
|
||||
if isinstance(data_arg, Dataset):
|
||||
ds_size = len(data_arg)
|
||||
data_arg = DataLoader(data_arg, batch_size=ds_size)
|
||||
loader = DataLoader(data_arg, batch_size=ds_size)
|
||||
data, targets = next(iter(loader))
|
||||
|
||||
if isinstance(data_arg, DataLoader):
|
||||
elif isinstance(data_arg, DataLoader):
|
||||
data = torch.tensor([])
|
||||
targets = torch.tensor([])
|
||||
for x, y in data_arg:
|
||||
@@ -87,11 +89,11 @@ def parse_data_arg(data_arg: Union[Dataset, DataLoader, list, tuple]):
|
||||
assert len(data_arg) == 2
|
||||
data, targets = data_arg
|
||||
if not isinstance(data, torch.Tensor):
|
||||
wmsg = f"Converting data to {torch.Tensor}."
|
||||
wmsg = f"Converting data to {torch.Tensor}..."
|
||||
warnings.warn(wmsg)
|
||||
data = torch.Tensor(data)
|
||||
if not isinstance(targets, torch.LongTensor):
|
||||
wmsg = f"Converting targets to {torch.LongTensor}."
|
||||
wmsg = f"Converting targets to {torch.LongTensor}..."
|
||||
warnings.warn(wmsg)
|
||||
targets = torch.LongTensor(targets)
|
||||
return data, targets
|
||||
|
Reference in New Issue
Block a user