ccolas commited on
Commit
cda8a3a
1 Parent(s): 563e52e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -253,7 +253,7 @@ def find_best_songs_for_mood(all_tracks_audio_features, genre_selected_indexes,
253
  return min_dist_indexes, n_candidates
254
 
255
  @st.cache
256
- def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris, target_mood):
257
  # sample exploration songs
258
  if exploration > 0:
259
  n_known = int(playlist_length * (1 - exploration))
@@ -276,19 +276,23 @@ def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_lengt
276
  dict_args_loose[f'max_{m}'] = min(1, target_mood[i_m] + 0.3)
277
  new_songs = []
278
  counter_seed = 0
 
279
  while len(new_songs) < n_new:
280
  try:
281
  print(seed_songs[counter_seed])
282
  print(dict_args)
283
- reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=seed_genres[counter_seed],
 
284
  market="from_token", country='from_token', **dict_args)['tracks']
285
  if len(reco) == 0:
286
  print('Using loose bounds')
287
- reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=seed_genres[counter_seed],
 
288
  market="from_token", country='from_token', **dict_args_loose)['tracks']
289
  if len(reco) == 0:
290
  print('Using looser bounds')
291
- reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=seed_genres[counter_seed],
 
292
  market="from_token", country='from_token', **dict_args_looser)['tracks']
293
  if len(reco) == 0:
294
  print('Removing bounds')
@@ -298,6 +302,7 @@ def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_lengt
298
  if r['uri'] not in all_tracks_uris and r['uri'] not in new_songs:
299
  new_songs.append(r['uri'])
300
  break
 
301
  except:
302
  pass
303
  print(counter_seed, len(new_songs))
@@ -348,6 +353,7 @@ def run_app():
348
  if custom_button or 'run_custom' in st.session_state.keys() or debug:
349
  st.session_state['run_custom'] = True
350
  checkboxes = st.session_state['checkboxes'].copy()
 
351
  init_time = time.time()
352
  genre_selected_indexes = filter_songs_by_genre(checkboxes, genres_labels, indexes_by_genre)
353
  if len(genre_selected_indexes) < 10:
@@ -380,7 +386,7 @@ def run_app():
380
  generation_button = centered_button(st.button, 'Generate playlist', n_columns=5)
381
  if generation_button:
382
  selected_tracks_uris = run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris,
383
- target_mood.flatten())
384
  print(f'9. run exploration: {time.time() - init_time:.2f}')
385
  init_time = time.time()
386
 
 
253
  return min_dist_indexes, n_candidates
254
 
255
  @st.cache
256
+ def run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris, target_mood, selected_genres):
257
  # sample exploration songs
258
  if exploration > 0:
259
  n_known = int(playlist_length * (1 - exploration))
 
276
  dict_args_loose[f'max_{m}'] = min(1, target_mood[i_m] + 0.3)
277
  new_songs = []
278
  counter_seed = 0
279
+ print(selected_genres)
280
  while len(new_songs) < n_new:
281
  try:
282
  print(seed_songs[counter_seed])
283
  print(dict_args)
284
+ np.random.shuffle(selected_genres)
285
+ reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=selected_genres,
286
  market="from_token", country='from_token', **dict_args)['tracks']
287
  if len(reco) == 0:
288
  print('Using loose bounds')
289
+ np.random.shuffle(selected_genres)
290
+ reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=selected_genres,
291
  market="from_token", country='from_token', **dict_args_loose)['tracks']
292
  if len(reco) == 0:
293
  print('Using looser bounds')
294
+ np.random.shuffle(selected_genres)
295
+ reco = sp.recommendations(seed_tracks=[seed_songs[counter_seed]], seed_genres=selected_genres,
296
  market="from_token", country='from_token', **dict_args_looser)['tracks']
297
  if len(reco) == 0:
298
  print('Removing bounds')
 
302
  if r['uri'] not in all_tracks_uris and r['uri'] not in new_songs:
303
  new_songs.append(r['uri'])
304
  break
305
+
306
  except:
307
  pass
308
  print(counter_seed, len(new_songs))
 
353
  if custom_button or 'run_custom' in st.session_state.keys() or debug:
354
  st.session_state['run_custom'] = True
355
  checkboxes = st.session_state['checkboxes'].copy()
356
+ selected_genres = [genres_labels[i] for i in range(len(genres_labels)) if checkboxes[i] and genres_labels[i] != 'unknown']
357
  init_time = time.time()
358
  genre_selected_indexes = filter_songs_by_genre(checkboxes, genres_labels, indexes_by_genre)
359
  if len(genre_selected_indexes) < 10:
 
386
  generation_button = centered_button(st.button, 'Generate playlist', n_columns=5)
387
  if generation_button:
388
  selected_tracks_uris = run_exploration(selected_tracks_uris, selected_tracks_genres, playlist_length, exploration, all_tracks_uris,
389
+ target_mood.flatten(), selected_genres)
390
  print(f'9. run exploration: {time.time() - init_time:.2f}')
391
  init_time = time.time()
392