diff --git a/prototorch/functions/initializers.py b/prototorch/functions/initializers.py index 18a4e04..4f9d74d 100644 --- a/prototorch/functions/initializers.py +++ b/prototorch/functions/initializers.py @@ -69,16 +69,13 @@ def stratified_mean(x_train, y_train, prototype_distribution): @register_initializer def stratified_random(x_train, y_train, prototype_distribution): - gen = torch.manual_seed(torch.initial_seed()) nprotos = torch.sum(prototype_distribution) pdim = x_train.shape[1] protos = torch.empty(nprotos, pdim) plabels = labels_from(prototype_distribution) for i, l in enumerate(plabels): xl = x_train[y_train == l] - rand_index = torch.zeros(1).long().random_(0, - xl.shape[1] - 1, - generator=gen) + rand_index = torch.zeros(1).long().random_(0, xl.shape[0] - 1) random_xl = xl[rand_index] protos[i] = random_xl return protos, plabels