Spaces:
Runtime error
Runtime error
| import numpy as np | |
def weighted_random_sample(items: np.ndarray, weights: np.ndarray, n: int) -> np.ndarray:
    """Draw ``n`` distinct items at random, favouring items with larger weights.

    Items are drawn one at a time with ``np.random.choice``. After each draw
    the chosen item is removed from the pool and the remaining weights are
    renormalised, so no item can appear twice in the result.

    Args:
        items: Array-like pool to sample from. Sampling happens along the
            first axis.
        weights: One-dimensional array-like of non-negative weights, one per
            item. They do not need to sum to 1; they are normalised before
            every draw.
        n: Number of items to draw. Values <= 0 yield an empty result.

    Returns:
        An array holding the ``n`` selected items in the order they were drawn.

    Raises:
        ValueError: If ``weights`` is not one-dimensional or its length
            differs from ``items``, if any weight is negative, if ``n``
            exceeds the number of items, or if fewer than ``n`` items have a
            non-zero weight.
    """
    # Accept plain sequences as well as arrays; boolean-mask indexing below
    # requires real ndarrays.
    items = np.asarray(items)
    weights = np.asarray(weights, dtype=float)

    if weights.ndim != 1 or len(weights) != len(items):
        raise ValueError("weights must be one-dimensional with one entry per item")
    if np.any(weights < 0):
        raise ValueError("weights must be non-negative")
    if n > len(items):
        raise ValueError(
            f"cannot draw {n} distinct items from a pool of {len(items)}"
        )

    indices = np.arange(len(items))
    out_indices = []
    for _ in range(n):
        total = weights.sum()
        # `not total > 0` also catches NaN totals, not just an exhausted pool.
        if not total > 0:
            raise ValueError(
                "fewer items with non-zero weight than requested sample size"
            )
        # Normalise on every draw (including the first) so np.random.choice
        # always receives a valid probability vector.
        chosen_index = np.random.choice(indices, p=weights / total)
        out_indices.append(chosen_index)

        # Remove the chosen item from the pool so it cannot be drawn again.
        mask = indices != chosen_index
        indices = indices[mask]
        weights = weights[mask]

    # Explicit integer dtype keeps the empty-selection case a valid index array.
    return items[np.asarray(out_indices, dtype=np.intp)]