ccolas commited on
Commit
0c95d20
1 Parent(s): 85abdd3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -187
app.py CHANGED
@@ -1,190 +1,158 @@
1
- from PIL import Image
2
- from utils import *
3
- from app_utils import *
 
 
 
 
4
 
5
- debug = False
6
  dir_path = os.path.dirname(os.path.realpath(__file__))
7
- # os.environ['FLASK_APP'] = dir_path + 'app2.py'
8
- # if debug: os.environ['FLASK_ENV'] = 'development'
9
- #
10
- #
11
- # app = Flask(__name__)
12
- # app.config['SECRET_KEY'] = os.urandom(64)
13
- # app.config['SESSION_TYPE'] = 'filesystem'
14
- # app.config['SESSION_FILE_DIR'] = './.flask_session/'
15
- # Session(app)
16
- st.set_page_config(
17
- page_title="EmotionPlaylist",
18
- page_icon="🎧",
19
- )
20
- st.title('Customize Emotional Playlists')
21
-
22
-
23
-
24
- def setup_streamlite():
25
- setup_credentials()
26
-
27
- print('Here1')
28
- image = Image.open(dir_path + '/image.png')
29
- st.image(image)
30
- st.markdown("This app let's you quickly build playlists in a customized way: ")
31
- st.markdown("* **It's easy**: you won't have to add songs one by one,\n"
32
- "* **You're in control**: you provide a source of candidate songs, select a list of genres and choose the mood for the playlist.")
33
-
34
- st.subheader("Step 1: Connect to your Spotify app")
35
- st.markdown("Log into your Spotify account to let the app create the custom playlist.")
36
- print('Here2')
37
- if 'login' not in st.session_state:
38
- sp, user_id = new_get_client(session=st.session_state)
39
- if sp != None:
40
- print("USER", user_id)
41
- legit_genres = sp.recommendation_genre_seeds()['genres']
42
- st.session_state['login'] = (sp, user_id, legit_genres)
43
-
44
- # if 'login' not in st.session_state:
45
- # login = centered_button(st.button, 'Log in', n_columns=7)
46
- # if login or debug:
47
- # sp, user_id = get_client(session=st.session_state)
48
- # user_id = sp.me()['id']
49
- # legit_genres = sp.recommendation_genre_seeds()['genres']
50
- # st.session_state['login'] = (sp, user_id, legit_genres)
51
-
52
- if 'login' in st.session_state or debug:
53
- print('Here8')
54
- if not debug: sp, user_id, legit_genres = st.session_state['login']
55
- st.success('You are logged in.')
56
-
57
- st.subheader("Step 2: Select candidate songs")
58
- st.markdown("This can be done in three ways: \n"
59
- "1. Get songs from a list of artists\n"
60
- "2. Get songs from a list of users (and their playlists)\n"
61
- "3. Get songs from a list of playlists.\n"
62
- "For this you'll need to collect the urls of artists, users and/or playlists by clicking on 'Share' and copying the urls."
63
- "You need to provide at least one source of music.")
64
-
65
- label_artist = "Add a list of artist urls, one per line (optional)"
66
- artists_links = st.text_area(label_artist, value="")
67
- users_playlists = "Add a list of users urls, one per line (optional)"
68
- users_links = st.text_area(users_playlists, value="")
69
- label_playlists = "Add a list of playlists urls, one per line (optional)"
70
- playlist_links = st.text_area(label_playlists, value="https://open.spotify.com/playlist/1H7a4q8JZArMQiidRy6qon?si=529184bbe93c4f73")
71
-
72
- button = centered_button(st.button, 'Extract music', n_columns=5)
73
- if button or debug:
74
- if playlist_links != "":
75
- playlist_uris = extract_uris_from_links(playlist_links, url_type='playlist')
76
- else:
77
- raise ValueError('Please enter a list of playlists')
78
- # Scanning playlists
79
- st.spinner(text="Scanning music sources..")
80
- data_tracks = get_all_tracks_from_playlists(sp, playlist_uris, verbose=True)
81
- st.success(f'{len(data_tracks.keys())} tracks found!')
82
-
83
- # Extract audio features
84
- st.spinner(text="Extracting audio features..")
85
- all_tracks_uris = np.array(list(data_tracks.keys()))
86
- all_audio_features = [data_tracks[uri]['track']['audio_features'] for uri in all_tracks_uris]
87
- all_tracks_audio_features = dict(zip(relevant_audio_features, [[audio_f[k] for audio_f in all_audio_features] for k in relevant_audio_features]))
88
- genres = dict()
89
- for index, uri in enumerate(all_tracks_uris):
90
- track = data_tracks[uri]
91
- track_genres = track['track']['genres']
92
- for g in track_genres:
93
- if g not in genres.keys():
94
- genres[g] = [index]
95
- else:
96
- genres[g].append(index)
97
- genres = aggregate_genres(genres, legit_genres)
98
- genres_labels = sorted(genres.keys())
99
- st.success(f'Audio features extracted!')
100
- st.session_state['music_extracted'] = dict(all_tracks_uris=all_tracks_uris,
101
- all_tracks_audio_features=all_tracks_audio_features,
102
- genres=genres,
103
- genres_labels=genres_labels)
104
-
105
- if 'music_extracted' in st.session_state.keys():
106
- all_tracks_uris = st.session_state['music_extracted']['all_tracks_uris']
107
- all_tracks_audio_features = st.session_state['music_extracted']['all_tracks_audio_features']
108
- genres = st.session_state['music_extracted']['genres']
109
- genres_labels = st.session_state['music_extracted']['genres_labels']
110
-
111
- st.subheader("Step 3: Customize it!")
112
- st.markdown("##### Which genres?")
113
- st.markdown("Check boxes to select genres, see how many tracks were selected below. Note: to check all, first uncheck all (bug).")
114
- columns = st.columns(np.ones(5))
115
- with columns[1]:
116
- check_all = st.button('Check all')
117
- with columns[3]:
118
- uncheck_all = st.button('Uncheck all')
119
-
120
- if 'checkboxes' not in st.session_state.keys():
121
- st.session_state['checkboxes'] = [True] * len(genres_labels)
122
-
123
- empty_checkboxes = wall_of_checkboxes(genres_labels, max_width=5)
124
- if check_all:
125
- st.session_state['checkboxes'] = [True] * len(genres_labels)
126
- if uncheck_all:
127
- st.session_state['checkboxes'] = [False] * len(genres_labels)
128
- for i_emc, emc in enumerate(empty_checkboxes):
129
- st.session_state['checkboxes'][i_emc] = emc.checkbox(genres_labels[i_emc], value=st.session_state['checkboxes'][i_emc])
130
-
131
-
132
- # filter songs by genres
133
- selected_labels = [genres_labels[i] for i in range(len(genres_labels)) if st.session_state['checkboxes'][i]]
134
- genre_selected_indexes = []
135
- for label in selected_labels:
136
- genre_selected_indexes += genres[label]
137
- genre_selected_indexes = np.array(sorted(set(genre_selected_indexes)))
138
- if len(genre_selected_indexes) < 10:
139
- st.warning('Please select more genres or add more music sources.')
140
  else:
141
- st.success(f'{len(genre_selected_indexes)} candidate tracks selected.')
142
-
143
- st.markdown("##### What's the mood?")
144
- valence = st.slider('Valence (0 negative, 100 positive)', min_value=0, max_value=100, value=100, step=1) / 100
145
- energy = st.slider('Energy (0 low, 100 high)', min_value=0, max_value=100, value=100, step=1) / 100
146
- danceability = st.slider('Danceability (0 low, 100 high)', min_value=0, max_value=100, value=100, step=1) / 100
147
-
148
- target_mood = np.array([valence, energy, danceability]).reshape(1, 3)
149
- candidate_moods = np.array([np.array(all_tracks_audio_features[feature])[genre_selected_indexes] for feature in ['valence', 'energy', 'danceability']]).T
150
-
151
- distances = np.sqrt(((candidate_moods - target_mood) ** 2).sum(axis=1))
152
- min_dist_indexes = np.argsort(distances)
153
-
154
- n_candidates = distances.shape[0]
155
- if n_candidates < 25:
156
- st.warning('Please add more music sources or select more genres.')
157
- else:
158
- playlist_length = st.number_input(f'Pick a playlist length, given {n_candidates} candidates.', min_value=5,
159
- value=min(10, n_candidates//5), max_value=n_candidates//5)
160
-
161
- selected_tracks_indexes = genre_selected_indexes[min_dist_indexes[:playlist_length]]
162
- selected_tracks_uris = all_tracks_uris[selected_tracks_indexes]
163
- np.random.shuffle(selected_tracks_uris)
164
- playlist_name = st.text_input('Playlist name', value='Mood Playlist')
165
- if playlist_name == '':
166
- st.warning('Please enter a playlist name.')
167
- else:
168
- generation_button = centered_button(st.button, 'Generate playlist', n_columns=5)
169
- if generation_button:
170
- description = f'Emotion Playlist for Valence: {valence}, Energy: {energy}, Danceability: {danceability}). ' \
171
- f'Playlist generated by the EmotionPlaylist app: https://huggingface.co/spaces/ccolas/EmotionPlaylist.'
172
- playlist_info = sp.user_playlist_create(user_id, playlist_name, public=True, collaborative=False, description=description)
173
- playlist_uri = playlist_info['uri'].split(':')[-1]
174
- sp.playlist_add_items(playlist_uri, selected_tracks_uris)
175
- st.write(
176
- f"""
177
- <html>
178
- <body>
179
- <center>
180
- <iframe style = "border-radius:12px" src="https://open.spotify.com/embed/playlist/{playlist_uri}" allowtransparency="true"
181
- allow="encrypted-media" width="80%" height="580" frameborder="0"></iframe></center></body></html>
182
- """, unsafe_allow_html=True)
183
-
184
- st.success(f'The playlist has been generated, find it [here](https://open.spotify.com/playlist/{playlist_uri}).')
185
-
186
-
187
- stop = 1
188
-
189
- if __name__ == '__main__':
190
- setup_streamlite()
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import os
4
+ import pickle
5
+ import spotipy
6
+ import spotipy.util as sp_util
7
+ import requests
8
 
 
9
  dir_path = os.path.dirname(os.path.realpath(__file__))
10
+
11
+ # current mess: https://github.com/plamere/spotipy/issues/632
12
+ def centered_button(func, text, n_columns=7, args=None):
13
+ columns = st.columns(np.ones(n_columns))
14
+ with columns[n_columns//2]:
15
+ return func(text)
16
+
17
+ # get credentials
18
+ def setup_credentials():
19
+ if 'client_id' in os.environ.keys() and 'client_secret' in os.environ.keys():
20
+ client_info = dict(client_id=os.environ['client_id'],
21
+ client_secret=os.environ['client_secret'])
22
+ else:
23
+ with open(dir_path + "/ids.pk", 'rb') as f:
24
+ client_info = pickle.load(f)
25
+
26
+ os.environ['SPOTIPY_CLIENT_ID'] = client_info['client_id']
27
+ os.environ['SPOTIPY_CLIENT_SECRET'] = client_info['client_secret']
28
+ os.environ['SPOTIPY_REDIRECT_URI'] = 'http://localhost:8080/'
29
+ return client_info
30
+
31
+ relevant_audio_features = ["danceability", "energy", "loudness", "mode", "valence", "tempo"]
32
+
33
+
34
+ def get_client():
35
+ scope = "playlist-modify-public"
36
+ token = sp_util.prompt_for_user_token(scope=scope)
37
+ sp = spotipy.Spotify(auth=token)
38
+ user_id = sp.me()['id']
39
+ return sp, user_id
40
+
41
+
42
+ def new_get_client(session):
43
+ scope = "playlist-modify-public"
44
+ print('Here3')
45
+
46
+ cache_handler = StreamlitCacheHandler(session)
47
+ auth_manager = spotipy.oauth2.SpotifyOAuth(scope=scope,
48
+ cache_handler=cache_handler,
49
+ show_dialog=True)
50
+ print('Here4')
51
+ sp, user_id = None, None
52
+
53
+ if requests.get("code"):
54
+ # Step 2. Being redirected from Spotify auth page
55
+ auth_manager.get_access_token(requests.get("code"))
56
+ print('Here7')
57
+ sp = spotipy.Spotify(auth_manager=auth_manager)
58
+ user_id = sp.me()['id']
59
+
60
+
61
+ if not auth_manager.validate_token(cache_handler.get_cached_token()):
62
+ print('Here6')
63
+ # Step 1. Display sign in link when no token
64
+ auth_url = auth_manager.get_authorize_url()
65
+ st.markdown(f'[Click here to log in]({auth_url})', unsafe_allow_html=True)
66
+ return sp, user_id
67
+
68
+
69
+
70
+ def extract_uris_from_links(links, url_type):
71
+ assert url_type in ['playlist', 'artist', 'user']
72
+ urls = links.split('\n')
73
+ uris = []
74
+ for url in urls:
75
+ if 'playlist' in url:
76
+ uri = url.split(f'{url_type}/')[-1].split('?')[0]
77
+ else:
78
+ uri = url.split('?')[0]
79
+ uris.append(uri)
80
+ return uris
81
+
82
+ def wall_of_checkboxes(labels, max_width=10):
83
+ n_labels = len(labels)
84
+ n_rows = int(np.ceil(n_labels/max_width))
85
+ checkboxes = []
86
+ for i in range(n_rows):
87
+ columns = st.columns(np.ones(max_width))
88
+ row_length = n_labels % max_width if i == n_rows - 1 else max_width
89
+ for j in range(row_length):
90
+ with columns[j]:
91
+ checkboxes.append(st.empty())
92
+ return checkboxes
93
+
94
+ def aggregate_genres(genres, legit_genres, verbose=False):
95
+ genres_output = dict()
96
+ legit_genres_formatted = [lg.replace('-', '').replace(' ', '') for lg in legit_genres]
97
+ for glabel in genres.keys():
98
+ if verbose: print('\n', glabel)
99
+ glabel_formatted = glabel.replace(' ', '').replace('-', '')
100
+ best_match = None
101
+ best_match_score = 0
102
+ for legit_glabel, legit_glabel_formatted in zip(legit_genres, legit_genres_formatted):
103
+ if 'jazz' in glabel_formatted:
104
+ best_match = 'jazz'
105
+ if verbose: print('\t', 'pop')
106
+ break
107
+ if 'ukpop' in glabel_formatted:
108
+ best_match = 'pop'
109
+ if verbose: print('\t', 'pop')
110
+ break
111
+ if legit_glabel_formatted == glabel_formatted:
112
+ if verbose: print('\t', legit_glabel_formatted)
113
+ best_match = legit_glabel
114
+ break
115
+ elif glabel_formatted in legit_glabel_formatted:
116
+ if verbose: print('\t', legit_glabel_formatted)
117
+ if len(glabel_formatted) > best_match_score:
118
+ best_match = legit_glabel
119
+ best_match_score = len(glabel_formatted)
120
+ elif legit_glabel_formatted in glabel_formatted:
121
+ if verbose: print('\t', legit_glabel_formatted)
122
+ if len(legit_glabel_formatted) > best_match_score:
123
+ best_match = legit_glabel
124
+ best_match_score = len(legit_glabel_formatted)
125
+
126
+ if best_match is not None:
127
+ if verbose: print('\t', '-->', best_match)
128
+ if best_match in genres_output.keys():
129
+ genres_output[best_match] += genres[glabel]
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  else:
131
+ genres_output[best_match] = genres[glabel]
132
+ return genres_output
133
+
134
+
135
+
136
+ class StreamlitCacheHandler(spotipy.cache_handler.CacheHandler):
137
+ """
138
+ A cache handler that stores the token info in the session framework
139
+ provided by streamlit.
140
+ """
141
+
142
+ def __init__(self, session):
143
+ self.session = session
144
+
145
+ def get_cached_token(self):
146
+ token_info = None
147
+ try:
148
+ token_info = self.session["token_info"]
149
+ except KeyError:
150
+ print("Token not found in the session")
151
+
152
+ return token_info
153
+
154
+ def save_token_to_cache(self, token_info):
155
+ try:
156
+ self.session["token_info"] = token_info
157
+ except Exception as e:
158
+ print("Error saving token to cache: " + str(e))