Emo_Play / utils.py
DjPapzin's picture
Upload 34 files
6fd136c
raw
history blame
836 Bytes
import numpy as np
def weighted_random_sample(items: np.ndarray, weights: np.ndarray, n: int) -> np.ndarray:
    """Draw ``n`` distinct elements of ``items`` with probability proportional to ``weights``.

    Behaves like ``np.random.choice(items, p=weights)`` repeated ``n`` times, but
    never returns the same position twice. After each draw, the chosen element is
    removed and the remaining weights are renormalised. Each later draw is
    therefore made only from the items not yet selected (weighted sampling
    without replacement). Randomness comes from the global ``np.random`` state,
    so ``np.random.seed`` makes results reproducible.

    Args:
        items (np.ndarray): Candidates to sample from, indexed along the first
            axis. Lists and other array-likes are converted with ``np.asarray``.
        weights (np.ndarray): Non-negative weight for each element of ``items``.
            They need not sum to 1; they are normalised before every draw.
        n (int): Number of distinct elements to draw.

    Returns:
        np.ndarray: The ``n`` selected elements, in the order they were drawn.

    Raises:
        ValueError: If ``items`` and ``weights`` differ in length, if ``n``
            exceeds ``len(items)``, or if fewer than ``n`` items carry a non-zero
            weight (so the sample cannot be completed). Negative or NaN weights
            are rejected by ``np.random.choice`` with a ``ValueError`` as well.
    """
    items = np.asarray(items)
    weights = np.asarray(weights, dtype=float)
    if len(weights) != len(items):
        raise ValueError(
            f"items and weights must have the same length "
            f"(got {len(items)} and {len(weights)})"
        )
    if n > len(items):
        raise ValueError(
            f"cannot draw {n} distinct items from a population of {len(items)}"
        )

    indices = np.arange(len(items))
    out_indices = []
    for _ in range(n):
        total = weights.sum()
        if total == 0:
            # Every remaining candidate has zero weight, so there is no valid
            # distribution left to draw from. Zero-weight items are never chosen,
            # so exactly len(out_indices) items had non-zero weight.
            raise ValueError(
                f"only {len(out_indices)} items have non-zero weight; "
                f"cannot draw {n} distinct items"
            )
        # Renormalise the surviving weights so np.random.choice gets a
        # probability vector that sums to 1.
        chosen_index = np.random.choice(indices, p=weights / total)
        out_indices.append(chosen_index)
        # Remove the chosen position so it cannot be drawn again.
        mask = indices != chosen_index
        indices = indices[mask]
        weights = weights[mask]
    # An explicit integer dtype keeps the n == 0 case a valid (empty) index.
    return items[np.asarray(out_indices, dtype=np.intp)]