File size: 836 Bytes
6fd136c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import numpy as np
# Function to perform weighted random sampling without replacement
def weighted_random_sample(items: np.array, weights: np.array, n: int) -> np.array:
    """
    Draw ``n`` distinct entries of ``items`` by weighted random sampling.

    Sampling is done without replacement: each position of ``items`` can be
    selected at most once, and the chance of picking a position is
    proportional to its weight among the positions not yet selected.
    Randomness comes from NumPy's global ``np.random`` state, so results are
    reproducible after ``np.random.seed``.

    Args:
        items (np.array): Candidates to sample from; indexed along the first
            axis. Any array-like is accepted and converted with ``np.asarray``.
        weights (np.array): One non-negative, finite weight per item. The
            weights do not need to sum to 1; they are normalized internally.
            Items with weight 0 are never selected.
        n (int): Number of items to draw. Values ``<= 0`` yield an empty
            result.

    Returns:
        np.array: The ``n`` selected items, in the order they were drawn.

    Raises:
        ValueError: If ``weights`` is not 1-D with one entry per item, holds
            negative or non-finite values, or if ``n`` exceeds the number of
            items with a non-zero weight.
    """
    items = np.asarray(items)

    # Nothing to draw: mirror an empty selection without touching the RNG.
    if n <= 0:
        return items[:0]

    probs = np.asarray(weights, dtype=float)
    if probs.ndim != 1 or len(probs) != len(items):
        raise ValueError("weights must be 1-D with exactly one entry per item")
    if not np.all(np.isfinite(probs)) or np.any(probs < 0):
        raise ValueError("weights must be finite and non-negative")
    # Zero-weight items can never be drawn, so they do not count as candidates.
    if np.count_nonzero(probs) < n:
        raise ValueError(
            "cannot draw n distinct items: fewer than n items have a non-zero weight"
        )

    # np.random.choice requires probabilities that sum to 1.
    probs = probs / probs.sum()

    # replace=False performs the "pick, remove, renormalize" procedure natively.
    chosen = np.random.choice(len(items), size=n, replace=False, p=probs)
    return items[chosen]
|