ccolas commited on
Commit
cb32cb9
1 Parent(s): eb07038

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -155
app.py CHANGED
@@ -1,158 +1,190 @@
1
- import streamlit as st
2
- import numpy as np
3
- import os
4
- import pickle
5
- import spotipy
6
- import spotipy.util as sp_util
7
- import requests
8
 
 
9
  dir_path = os.path.dirname(os.path.realpath(__file__))
10
-
11
- # current mess: https://github.com/plamere/spotipy/issues/632
12
- def centered_button(func, text, n_columns=7, args=None):
13
- columns = st.columns(np.ones(n_columns))
14
- with columns[n_columns//2]:
15
- return func(text)
16
-
17
- # get credentials
18
- def setup_credentials():
19
- if 'client_id' in os.environ.keys() and 'client_secret' in os.environ.keys():
20
- client_info = dict(client_id=os.environ['client_id'],
21
- client_secret=os.environ['client_secret'])
22
- else:
23
- with open(dir_path + "/ids.pk", 'rb') as f:
24
- client_info = pickle.load(f)
25
-
26
- os.environ['SPOTIPY_CLIENT_ID'] = client_info['client_id']
27
- os.environ['SPOTIPY_CLIENT_SECRET'] = client_info['client_secret']
28
- os.environ['SPOTIPY_REDIRECT_URI'] = 'http://localhost:8080/'
29
- return client_info
30
-
31
- relevant_audio_features = ["danceability", "energy", "loudness", "mode", "valence", "tempo"]
32
-
33
-
34
- def get_client():
35
- scope = "playlist-modify-public"
36
- token = sp_util.prompt_for_user_token(scope=scope)
37
- sp = spotipy.Spotify(auth=token)
38
- user_id = sp.me()['id']
39
- return sp, user_id
40
-
41
-
42
- def new_get_client(session):
43
- scope = "playlist-modify-public"
44
- print('Here3')
45
-
46
- cache_handler = StreamlitCacheHandler(session)
47
- auth_manager = spotipy.oauth2.SpotifyOAuth(scope=scope,
48
- cache_handler=cache_handler,
49
- show_dialog=True)
50
- print('Here4')
51
- sp, user_id = None, None
52
-
53
- if requests.get("code"):
54
- # Step 2. Being redirected from Spotify auth page
55
- auth_manager.get_access_token(requests.get("code"))
56
- print('Here7')
57
- sp = spotipy.Spotify(auth_manager=auth_manager)
58
- user_id = sp.me()['id']
59
-
60
-
61
- if not auth_manager.validate_token(cache_handler.get_cached_token()):
62
- print('Here6')
63
- # Step 1. Display sign in link when no token
64
- auth_url = auth_manager.get_authorize_url()
65
- st.markdown(f'[Click here to log in]({auth_url})', unsafe_allow_html=True)
66
- return sp, user_id
67
-
68
-
69
-
70
- def extract_uris_from_links(links, url_type):
71
- assert url_type in ['playlist', 'artist', 'user']
72
- urls = links.split('\n')
73
- uris = []
74
- for url in urls:
75
- if 'playlist' in url:
76
- uri = url.split(f'{url_type}/')[-1].split('?')[0]
77
- else:
78
- uri = url.split('?')[0]
79
- uris.append(uri)
80
- return uris
81
-
82
- def wall_of_checkboxes(labels, max_width=10):
83
- n_labels = len(labels)
84
- n_rows = int(np.ceil(n_labels/max_width))
85
- checkboxes = []
86
- for i in range(n_rows):
87
- columns = st.columns(np.ones(max_width))
88
- row_length = n_labels % max_width if i == n_rows - 1 else max_width
89
- for j in range(row_length):
90
- with columns[j]:
91
- checkboxes.append(st.empty())
92
- return checkboxes
93
-
94
- def aggregate_genres(genres, legit_genres, verbose=False):
95
- genres_output = dict()
96
- legit_genres_formatted = [lg.replace('-', '').replace(' ', '') for lg in legit_genres]
97
- for glabel in genres.keys():
98
- if verbose: print('\n', glabel)
99
- glabel_formatted = glabel.replace(' ', '').replace('-', '')
100
- best_match = None
101
- best_match_score = 0
102
- for legit_glabel, legit_glabel_formatted in zip(legit_genres, legit_genres_formatted):
103
- if 'jazz' in glabel_formatted:
104
- best_match = 'jazz'
105
- if verbose: print('\t', 'pop')
106
- break
107
- if 'ukpop' in glabel_formatted:
108
- best_match = 'pop'
109
- if verbose: print('\t', 'pop')
110
- break
111
- if legit_glabel_formatted == glabel_formatted:
112
- if verbose: print('\t', legit_glabel_formatted)
113
- best_match = legit_glabel
114
- break
115
- elif glabel_formatted in legit_glabel_formatted:
116
- if verbose: print('\t', legit_glabel_formatted)
117
- if len(glabel_formatted) > best_match_score:
118
- best_match = legit_glabel
119
- best_match_score = len(glabel_formatted)
120
- elif legit_glabel_formatted in glabel_formatted:
121
- if verbose: print('\t', legit_glabel_formatted)
122
- if len(legit_glabel_formatted) > best_match_score:
123
- best_match = legit_glabel
124
- best_match_score = len(legit_glabel_formatted)
125
-
126
- if best_match is not None:
127
- if verbose: print('\t', '-->', best_match)
128
- if best_match in genres_output.keys():
129
- genres_output[best_match] += genres[glabel]
130
  else:
131
- genres_output[best_match] = genres[glabel]
132
- return genres_output
133
-
134
-
135
-
136
- class StreamlitCacheHandler(spotipy.cache_handler.CacheHandler):
137
- """
138
- A cache handler that stores the token info in the session framework
139
- provided by streamlit.
140
- """
141
-
142
- def __init__(self, session):
143
- self.session = session
144
-
145
- def get_cached_token(self):
146
- token_info = None
147
- try:
148
- token_info = self.session["token_info"]
149
- except KeyError:
150
- print("Token not found in the session")
151
-
152
- return token_info
153
-
154
- def save_token_to_cache(self, token_info):
155
- try:
156
- self.session["token_info"] = token_info
157
- except Exception as e:
158
- print("Error saving token to cache: " + str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from utils import *
3
+ from app_utils import *
 
 
 
 
4
 
5
+ debug = False
6
  dir_path = os.path.dirname(os.path.realpath(__file__))
7
+ # os.environ['FLASK_APP'] = dir_path + 'app2.py'
8
+ # if debug: os.environ['FLASK_ENV'] = 'development'
9
+ #
10
+ #
11
+ # app = Flask(__name__)
12
+ # app.config['SECRET_KEY'] = os.urandom(64)
13
+ # app.config['SESSION_TYPE'] = 'filesystem'
14
+ # app.config['SESSION_FILE_DIR'] = './.flask_session/'
15
+ # Session(app)
16
+ st.set_page_config(
17
+ page_title="EmotionPlaylist",
18
+ page_icon="🎧",
19
+ )
20
+ st.title('Customize Emotional Playlists')
21
+
22
+
23
+
24
+ def setup_streamlite():
25
+ setup_credentials()
26
+
27
+ print('Here1')
28
+ image = Image.open(dir_path + '/image.png')
29
+ st.image(image)
30
+ st.markdown("This app let's you quickly build playlists in a customized way: ")
31
+ st.markdown("* **It's easy**: you won't have to add songs one by one,\n"
32
+ "* **You're in control**: you provide a source of candidate songs, select a list of genres and choose the mood for the playlist.")
33
+
34
+ st.subheader("Step 1: Connect to your Spotify app")
35
+ st.markdown("Log into your Spotify account to let the app create the custom playlist.")
36
+ print('Here2')
37
+ if 'login' not in st.session_state:
38
+ sp, user_id = new_get_client(session=st.session_state)
39
+ if sp != None:
40
+ print("USER", user_id)
41
+ legit_genres = sp.recommendation_genre_seeds()['genres']
42
+ st.session_state['login'] = (sp, user_id, legit_genres)
43
+
44
+ # if 'login' not in st.session_state:
45
+ # login = centered_button(st.button, 'Log in', n_columns=7)
46
+ # if login or debug:
47
+ # sp, user_id = get_client(session=st.session_state)
48
+ # user_id = sp.me()['id']
49
+ # legit_genres = sp.recommendation_genre_seeds()['genres']
50
+ # st.session_state['login'] = (sp, user_id, legit_genres)
51
+
52
+ if 'login' in st.session_state or debug:
53
+ print('Here8')
54
+ if not debug: sp, user_id, legit_genres = st.session_state['login']
55
+ st.success('You are logged in.')
56
+
57
+ st.subheader("Step 2: Select candidate songs")
58
+ st.markdown("This can be done in three ways: \n"
59
+ "1. Get songs from a list of artists\n"
60
+ "2. Get songs from a list of users (and their playlists)\n"
61
+ "3. Get songs from a list of playlists.\n"
62
+ "For this you'll need to collect the urls of artists, users and/or playlists by clicking on 'Share' and copying the urls."
63
+ "You need to provide at least one source of music.")
64
+
65
+ label_artist = "Add a list of artist urls, one per line (optional)"
66
+ artists_links = st.text_area(label_artist, value="")
67
+ users_playlists = "Add a list of users urls, one per line (optional)"
68
+ users_links = st.text_area(users_playlists, value="")
69
+ label_playlists = "Add a list of playlists urls, one per line (optional)"
70
+ playlist_links = st.text_area(label_playlists, value="https://open.spotify.com/playlist/1H7a4q8JZArMQiidRy6qon?si=529184bbe93c4f73")
71
+
72
+ button = centered_button(st.button, 'Extract music', n_columns=5)
73
+ if button or debug:
74
+ if playlist_links != "":
75
+ playlist_uris = extract_uris_from_links(playlist_links, url_type='playlist')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  else:
77
+ raise ValueError('Please enter a list of playlists')
78
+ # Scanning playlists
79
+ st.spinner(text="Scanning music sources..")
80
+ data_tracks = get_all_tracks_from_playlists(sp, playlist_uris, verbose=True)
81
+ st.success(f'{len(data_tracks.keys())} tracks found!')
82
+
83
+ # Extract audio features
84
+ st.spinner(text="Extracting audio features..")
85
+ all_tracks_uris = np.array(list(data_tracks.keys()))
86
+ all_audio_features = [data_tracks[uri]['track']['audio_features'] for uri in all_tracks_uris]
87
+ all_tracks_audio_features = dict(zip(relevant_audio_features, [[audio_f[k] for audio_f in all_audio_features] for k in relevant_audio_features]))
88
+ genres = dict()
89
+ for index, uri in enumerate(all_tracks_uris):
90
+ track = data_tracks[uri]
91
+ track_genres = track['track']['genres']
92
+ for g in track_genres:
93
+ if g not in genres.keys():
94
+ genres[g] = [index]
95
+ else:
96
+ genres[g].append(index)
97
+ genres = aggregate_genres(genres, legit_genres)
98
+ genres_labels = sorted(genres.keys())
99
+ st.success(f'Audio features extracted!')
100
+ st.session_state['music_extracted'] = dict(all_tracks_uris=all_tracks_uris,
101
+ all_tracks_audio_features=all_tracks_audio_features,
102
+ genres=genres,
103
+ genres_labels=genres_labels)
104
+
105
+ if 'music_extracted' in st.session_state.keys():
106
+ all_tracks_uris = st.session_state['music_extracted']['all_tracks_uris']
107
+ all_tracks_audio_features = st.session_state['music_extracted']['all_tracks_audio_features']
108
+ genres = st.session_state['music_extracted']['genres']
109
+ genres_labels = st.session_state['music_extracted']['genres_labels']
110
+
111
+ st.subheader("Step 3: Customize it!")
112
+ st.markdown("##### Which genres?")
113
+ st.markdown("Check boxes to select genres, see how many tracks were selected below. Note: to check all, first uncheck all (bug).")
114
+ columns = st.columns(np.ones(5))
115
+ with columns[1]:
116
+ check_all = st.button('Check all')
117
+ with columns[3]:
118
+ uncheck_all = st.button('Uncheck all')
119
+
120
+ if 'checkboxes' not in st.session_state.keys():
121
+ st.session_state['checkboxes'] = [True] * len(genres_labels)
122
+
123
+ empty_checkboxes = wall_of_checkboxes(genres_labels, max_width=5)
124
+ if check_all:
125
+ st.session_state['checkboxes'] = [True] * len(genres_labels)
126
+ if uncheck_all:
127
+ st.session_state['checkboxes'] = [False] * len(genres_labels)
128
+ for i_emc, emc in enumerate(empty_checkboxes):
129
+ st.session_state['checkboxes'][i_emc] = emc.checkbox(genres_labels[i_emc], value=st.session_state['checkboxes'][i_emc])
130
+
131
+
132
+ # filter songs by genres
133
+ selected_labels = [genres_labels[i] for i in range(len(genres_labels)) if st.session_state['checkboxes'][i]]
134
+ genre_selected_indexes = []
135
+ for label in selected_labels:
136
+ genre_selected_indexes += genres[label]
137
+ genre_selected_indexes = np.array(sorted(set(genre_selected_indexes)))
138
+ if len(genre_selected_indexes) < 10:
139
+ st.warning('Please select more genres or add more music sources.')
140
+ else:
141
+ st.success(f'{len(genre_selected_indexes)} candidate tracks selected.')
142
+
143
+ st.markdown("##### What's the mood?")
144
+ valence = st.slider('Valence (0 negative, 100 positive)', min_value=0, max_value=100, value=100, step=1) / 100
145
+ energy = st.slider('Energy (0 low, 100 high)', min_value=0, max_value=100, value=100, step=1) / 100
146
+ danceability = st.slider('Danceability (0 low, 100 high)', min_value=0, max_value=100, value=100, step=1) / 100
147
+
148
+ target_mood = np.array([valence, energy, danceability]).reshape(1, 3)
149
+ candidate_moods = np.array([np.array(all_tracks_audio_features[feature])[genre_selected_indexes] for feature in ['valence', 'energy', 'danceability']]).T
150
+
151
+ distances = np.sqrt(((candidate_moods - target_mood) ** 2).sum(axis=1))
152
+ min_dist_indexes = np.argsort(distances)
153
+
154
+ n_candidates = distances.shape[0]
155
+ if n_candidates < 25:
156
+ st.warning('Please add more music sources or select more genres.')
157
+ else:
158
+ playlist_length = st.number_input(f'Pick a playlist length, given {n_candidates} candidates.', min_value=5,
159
+ value=min(10, n_candidates//5), max_value=n_candidates//5)
160
+
161
+ selected_tracks_indexes = genre_selected_indexes[min_dist_indexes[:playlist_length]]
162
+ selected_tracks_uris = all_tracks_uris[selected_tracks_indexes]
163
+ np.random.shuffle(selected_tracks_uris)
164
+ playlist_name = st.text_input('Playlist name', value='Mood Playlist')
165
+ if playlist_name == '':
166
+ st.warning('Please enter a playlist name.')
167
+ else:
168
+ generation_button = centered_button(st.button, 'Generate playlist', n_columns=5)
169
+ if generation_button:
170
+ description = f'Emotion Playlist for Valence: {valence}, Energy: {energy}, Danceability: {danceability}). ' \
171
+ f'Playlist generated by the EmotionPlaylist app: https://huggingface.co/spaces/ccolas/EmotionPlaylist.'
172
+ playlist_info = sp.user_playlist_create(user_id, playlist_name, public=True, collaborative=False, description=description)
173
+ playlist_uri = playlist_info['uri'].split(':')[-1]
174
+ sp.playlist_add_items(playlist_uri, selected_tracks_uris)
175
+ st.write(
176
+ f"""
177
+ <html>
178
+ <body>
179
+ <center>
180
+ <iframe style = "border-radius:12px" src="https://open.spotify.com/embed/playlist/{playlist_uri}" allowtransparency="true"
181
+ allow="encrypted-media" width="80%" height="580" frameborder="0"></iframe></center></body></html>
182
+ """, unsafe_allow_html=True)
183
+
184
+ st.success(f'The playlist has been generated, find it [here](https://open.spotify.com/playlist/{playlist_uri}).')
185
+
186
+
187
+ stop = 1
188
+
189
+ if __name__ == '__main__':
190
+ setup_streamlite()