[BUGFIX] Stratified functions work on GPU now
This commit is contained in:
parent
1e0a8392a2
commit
b724a28a6f
@ -39,7 +39,7 @@ def stratify_with(values: torch.Tensor,
|
|||||||
# skip if stratification is trivial
|
# skip if stratification is trivial
|
||||||
return values
|
return values
|
||||||
batch_size = values.size()[0]
|
batch_size = values.size()[0]
|
||||||
winning_values = torch.zeros(num_classes, batch_size)
|
winning_values = torch.zeros(num_classes, batch_size, device=labels.device)
|
||||||
filler = torch.full_like(values.T, fill_value=fill_value)
|
filler = torch.full_like(values.T, fill_value=fill_value)
|
||||||
for i, cl in enumerate(clabels):
|
for i, cl in enumerate(clabels):
|
||||||
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
matcher = torch.eq(labels.unsqueeze(dim=1), cl)
|
||||||
|
Loading…
Reference in New Issue
Block a user