ccolas commited on
Commit
1767f84
1 Parent(s): 013be66

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +257 -0
app.py ADDED
@@ -0,0 +1,257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import spotipy.util as util
4
+ import pickle
5
+ import spotipy
6
+ from utils import *
7
+
8
+ st.set_page_config(
9
+ page_title="EmotionPlaylist",
10
+ page_icon="🎧",
11
+ )
12
+ debug = False
13
+ dir_path = os.path.dirname(os.path.realpath(__file__))
14
+
15
+ st.title('Customize Emotional Playlists')
16
+
17
+ def centered_button(func, text, n_columns=7, args=None):
18
+ columns = st.columns(np.ones(n_columns))
19
+ with columns[n_columns//2]:
20
+ return func(text)
21
+
22
+ # get credentials
23
+ def setup_credentials():
24
+ if 'client_id' in os.environ.keys() and 'client_secret' in os.environ.keys():
25
+ client_info = dict(client_id=os.environ['client_id'],
26
+ client_secret=os.environ['client_secret'])
27
+ else:
28
+ with open(dir_path + "/ids.pk", 'rb') as f:
29
+ client_info = pickle.load(f)
30
+
31
+ os.environ['SPOTIPY_CLIENT_ID'] = client_info['client_id']
32
+ os.environ['SPOTIPY_CLIENT_SECRET'] = client_info['client_secret']
33
+ os.environ['SPOTIPY_REDIRECT_URI'] = 'http://localhost:8080/'
34
+
35
+ relevant_audio_features = ["danceability", "energy", "loudness", "mode", "valence", "tempo"]
36
+
37
+
38
+ def get_client():
39
+ scope = "playlist-modify-public"
40
+ token = util.prompt_for_user_token(scope=scope)
41
+ sp = spotipy.Spotify(auth=token)
42
+ user_id = sp.me()['id']
43
+ return sp, user_id
44
+
45
+ def extract_uris_from_links(links, url_type):
46
+ assert url_type in ['playlist', 'artist', 'user']
47
+ urls = links.split('\n')
48
+ uris = []
49
+ for url in urls:
50
+ if 'playlist' in url:
51
+ uri = url.split(f'{url_type}/')[-1].split('?')[0]
52
+ else:
53
+ uri = url.split('?')[0]
54
+ uris.append(uri)
55
+ return uris
56
+
57
+ def wall_of_checkboxes(labels, max_width=10):
58
+ n_labels = len(labels)
59
+ n_rows = int(np.ceil(n_labels/max_width))
60
+ checkboxes = []
61
+ for i in range(n_rows):
62
+ columns = st.columns(np.ones(max_width))
63
+ row_length = n_labels % max_width if i == n_rows - 1 else max_width
64
+ for j in range(row_length):
65
+ with columns[j]:
66
+ checkboxes.append(st.empty())
67
+ return checkboxes
68
+
69
+ def aggregate_genres(genres, legit_genres, verbose=False):
70
+ genres_output = dict()
71
+ legit_genres_formatted = [lg.replace('-', '').replace(' ', '') for lg in legit_genres]
72
+ for glabel in genres.keys():
73
+ if verbose: print('\n', glabel)
74
+ glabel_formatted = glabel.replace(' ', '').replace('-', '')
75
+ best_match = None
76
+ best_match_score = 0
77
+ for legit_glabel, legit_glabel_formatted in zip(legit_genres, legit_genres_formatted):
78
+ if 'jazz' in glabel_formatted:
79
+ best_match = 'jazz'
80
+ if verbose: print('\t', 'pop')
81
+ break
82
+ if 'ukpop' in glabel_formatted:
83
+ best_match = 'pop'
84
+ if verbose: print('\t', 'pop')
85
+ break
86
+ if legit_glabel_formatted == glabel_formatted:
87
+ if verbose: print('\t', legit_glabel_formatted)
88
+ best_match = legit_glabel
89
+ break
90
+ elif glabel_formatted in legit_glabel_formatted:
91
+ if verbose: print('\t', legit_glabel_formatted)
92
+ if len(glabel_formatted) > best_match_score:
93
+ best_match = legit_glabel
94
+ best_match_score = len(glabel_formatted)
95
+ elif legit_glabel_formatted in glabel_formatted:
96
+ if verbose: print('\t', legit_glabel_formatted)
97
+ if len(legit_glabel_formatted) > best_match_score:
98
+ best_match = legit_glabel
99
+ best_match_score = len(legit_glabel_formatted)
100
+
101
+ if best_match is not None:
102
+ if verbose: print('\t', '-->', best_match)
103
+ if best_match in genres_output.keys():
104
+ genres_output[best_match] += genres[glabel]
105
+ else:
106
+ genres_output[best_match] = genres[glabel]
107
+ return genres_output
108
+
109
+ def setup_streamlite():
110
+ setup_credentials()
111
+ image = Image.open(dir_path + '/image.png')
112
+ st.image(image)
113
+
114
+ st.markdown("This app let's you quickly build playlists in a customized way: ")
115
+ st.markdown("* **It's easy**: you won't have to add songs one by one,\n"
116
+ "* **You're in control**: you provide a source of candidate songs, select a list of genres and choose the mood for the playlist.")
117
+
118
+ st.subheader("Step 1: Connect to your Spotify app")
119
+ st.markdown("Log into your Spotify account to let the app create the custom playlist.")
120
+ if 'login' not in st.session_state:
121
+ login = centered_button(st.button, 'Log in', n_columns=7)
122
+ if login or debug:
123
+ sp, user_id = get_client()
124
+ legit_genres = sp.recommendation_genre_seeds()['genres']
125
+ st.session_state['login'] = (sp, user_id, legit_genres)
126
+
127
+ if 'login' in st.session_state or debug:
128
+ if not debug: sp, user_id, legit_genres = st.session_state['login']
129
+ st.success('You are logged in.')
130
+
131
+ st.subheader("Step 2: Select candidate songs")
132
+ st.markdown("This can be done in three ways: \n"
133
+ "1. Get songs from a list of artists\n"
134
+ "2. Get songs from a list of users (and their playlists)\n"
135
+ "3. Get songs from a list of playlists.\n"
136
+ "For this you'll need to collect the urls of artists, users and/or playlists by clicking on 'Share' and copying the urls."
137
+ "You need to provide at least one source of music.")
138
+
139
+ label_artist = "Add a list of artist urls, one per line (optional)"
140
+ artists_links = st.text_area(label_artist, value="")
141
+ users_playlists = "Add a list of users urls, one per line (optional)"
142
+ users_links = st.text_area(users_playlists, value="")
143
+ label_playlists = "Add a list of playlists urls, one per line (optional)"
144
+ playlist_links = st.text_area(label_playlists, value="https://open.spotify.com/playlist/1H7a4q8JZArMQiidRy6qon?si=529184bbe93c4f73")
145
+
146
+ button = centered_button(st.button, 'Extract music', n_columns=5)
147
+ if button or debug:
148
+ if playlist_links != "":
149
+ playlist_uris = extract_uris_from_links(playlist_links, url_type='playlist')
150
+ else:
151
+ raise ValueError('Please enter a list of playlists')
152
+ # Scanning playlists
153
+ st.spinner(text="Scanning music sources..")
154
+ data_tracks = get_all_tracks_from_playlists(sp, playlist_uris, verbose=True)
155
+ st.success(f'{len(data_tracks.keys())} tracks found!')
156
+
157
+ # Extract audio features
158
+ st.spinner(text="Extracting audio features..")
159
+ all_tracks_uris = np.array(list(data_tracks.keys()))
160
+ all_audio_features = [data_tracks[uri]['track']['audio_features'] for uri in all_tracks_uris]
161
+ all_tracks_audio_features = dict(zip(relevant_audio_features, [[audio_f[k] for audio_f in all_audio_features] for k in relevant_audio_features]))
162
+ genres = dict()
163
+ for index, uri in enumerate(all_tracks_uris):
164
+ track = data_tracks[uri]
165
+ track_genres = track['track']['genres']
166
+ for g in track_genres:
167
+ if g not in genres.keys():
168
+ genres[g] = [index]
169
+ else:
170
+ genres[g].append(index)
171
+ genres = aggregate_genres(genres, legit_genres)
172
+ genres_labels = sorted(genres.keys())
173
+ st.success(f'Audio features extracted!')
174
+ st.session_state['music_extracted'] = dict(all_tracks_uris=all_tracks_uris,
175
+ all_tracks_audio_features=all_tracks_audio_features,
176
+ genres=genres,
177
+ genres_labels=genres_labels)
178
+
179
+ if 'music_extracted' in st.session_state.keys():
180
+ all_tracks_uris = st.session_state['music_extracted']['all_tracks_uris']
181
+ all_tracks_audio_features = st.session_state['music_extracted']['all_tracks_audio_features']
182
+ genres = st.session_state['music_extracted']['genres']
183
+ genres_labels = st.session_state['music_extracted']['genres_labels']
184
+
185
+ st.subheader("Step 3: Customize it!")
186
+ st.markdown("##### Which genres?")
187
+ st.markdown("Check boxes to select genres, see how many tracks were selected below. Note: to check all, first uncheck all (bug).")
188
+ columns = st.columns(np.ones(5))
189
+ with columns[1]:
190
+ check_all = st.button('Check all')
191
+ with columns[3]:
192
+ uncheck_all = st.button('Uncheck all')
193
+
194
+ if 'checkboxes' not in st.session_state.keys():
195
+ st.session_state['checkboxes'] = [True] * len(genres_labels)
196
+
197
+ empty_checkboxes = wall_of_checkboxes(genres_labels, max_width=5)
198
+ if check_all:
199
+ st.session_state['checkboxes'] = [True] * len(genres_labels)
200
+ if uncheck_all:
201
+ st.session_state['checkboxes'] = [False] * len(genres_labels)
202
+ for i_emc, emc in enumerate(empty_checkboxes):
203
+ st.session_state['checkboxes'][i_emc] = emc.checkbox(genres_labels[i_emc], value=st.session_state['checkboxes'][i_emc])
204
+
205
+
206
+ # filter songs by genres
207
+ selected_labels = [genres_labels[i] for i in range(len(genres_labels)) if st.session_state['checkboxes'][i]]
208
+ genre_selected_indexes = []
209
+ for label in selected_labels:
210
+ genre_selected_indexes += genres[label]
211
+ genre_selected_indexes = np.array(sorted(set(genre_selected_indexes)))
212
+ if len(genre_selected_indexes) < 10:
213
+ st.warning('Please select more genres or add more music sources.')
214
+ else:
215
+ st.success(f'{len(genre_selected_indexes)} candidate tracks selected.')
216
+
217
+ st.markdown("##### What's the mood?")
218
+ valence = st.slider('Valence (0 negative, 100 positive)', min_value=0, max_value=100, value=100, step=1) / 100
219
+ energy = st.slider('Energy (0 low, 100 high)', min_value=0, max_value=100, value=100, step=1) / 100
220
+ danceability = st.slider('Danceability (0 low, 100 high)', min_value=0, max_value=100, value=100, step=1) / 100
221
+
222
+ target_mood = np.array([valence, energy, danceability]).reshape(1, 3)
223
+ candidate_moods = np.array([np.array(all_tracks_audio_features[feature])[genre_selected_indexes] for feature in ['valence', 'energy', 'danceability']]).T
224
+
225
+ distances = np.sqrt(((candidate_moods - target_mood) ** 2).sum(axis=1))
226
+ min_dist_indexes = np.argsort(distances)
227
+
228
+ n_candidates = distances.shape[0]
229
+ if n_candidates < 25:
230
+ st.warning('Please add more music sources or select more genres.')
231
+ else:
232
+ playlist_length = st.number_input(f'Pick a playlist length, given {n_candidates} candidates.', min_value=5,
233
+ value=min(10, n_candidates//5), max_value=n_candidates//5)
234
+
235
+ selected_tracks_indexes = genre_selected_indexes[min_dist_indexes[:playlist_length]]
236
+ selected_tracks_uris = all_tracks_uris[selected_tracks_indexes]
237
+
238
+ playlist_name = st.text_input('Playlist name', value='Mood Playlist')
239
+ if playlist_name == '':
240
+ st.warning('Please enter a playlist name.')
241
+ else:
242
+ generation_button = st.button('Generate playlist')
243
+ if generation_button:
244
+ description = f'Emotion Playlist for Valence: {valence}, Energy: {energy}, Danceability: {danceability}). ' \
245
+ f'Playlist generated by the EmotionPlaylist app: https://huggingface.co/spaces/ccolas/EmotionPlaylist.'
246
+ playlist_info = sp.user_playlist_create(user_id, playlist_name, public=True, collaborative=False, description=description)
247
+ playlist_uri = playlist_info['uri'].split(':')[-1]
248
+ sp.playlist_add_items(playlist_uri, selected_tracks_uris)
249
+
250
+ st.success(f'The playlist has been generated, find it [here](https://open.spotify.com/playlist/{playlist_uri}).')
251
+
252
+
253
+ stop = 1
254
+
255
+
256
+ if __name__ == '__main__':
257
+ setup_streamlite()