Francesco commited on
Commit
3444304
1 Parent(s): e3c4cb8

added new way to sample songs that prevents duplicates

Browse files
Files changed (1) hide show
  1. utils.py +30 -0
utils.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def weighted_random_sample(items: np.array, weights: np.array, n: int) -> np.array:
5
+ """
6
+ Does np.random.choice but ensuring we don't have duplicates in the final result
7
+
8
+ Args:
9
+ items (np.array): _description_
10
+ weights (np.array): _description_
11
+ n (int): _description_
12
+
13
+ Returns:
14
+ np.array: _description_
15
+ """
16
+ indices = np.arange(len(items))
17
+ out_indices = []
18
+
19
+ for _ in range(n):
20
+ chosen_index = np.random.choice(indices, p=weights)
21
+ out_indices.append(chosen_index)
22
+
23
+ mask = indices != chosen_index
24
+ indices = indices[mask]
25
+ weights = weights[mask]
26
+
27
+ if weights.sum() != 0:
28
+ weights = weights / weights.sum()
29
+
30
+ return items[out_indices]