# EmotionPlaylist / app_utils.py
# Author: ccolas (commit 9be7350: "Update app_utils.py")
import streamlit as st
import numpy as np
import os
import pickle
import spotipy
import spotipy.util as sp_util
# Directory containing this module, with symlinks resolved; ids.pk is looked up here.
dir_path = os.path.split(os.path.realpath(__file__))[0]
# current mess: https://github.com/plamere/spotipy/issues/632
def centered_button(func, text, n_columns=7, disabled=False, args=None):
    """Render ``func(text)`` inside the middle one of ``n_columns`` columns.

    Args:
        func: Streamlit widget callable (e.g. ``st.button``). When its string
            representation contains 'button', ``disabled`` is forwarded to it.
        text: label passed as the first positional argument of ``func``.
        n_columns: number of columns created with equal weights; the widget
            is placed in column ``n_columns // 2``.
        disabled: forwarded as ``disabled=`` to button-like widgets only.
        args: accepted but currently unused.

    Returns:
        Whatever ``func`` returns.
    """
    middle_column = st.columns(np.ones(n_columns))[n_columns // 2]
    with middle_column:
        extra_kwargs = {'disabled': disabled} if 'button' in str(func) else {}
        return func(text, **extra_kwargs)
# get credentials
def setup_credentials():
    """Load the Spotify app credentials and export them for spotipy.

    Credentials are read from the ``client_id`` / ``client_secret``
    environment variables when both are set; otherwise they are unpickled
    from ``ids.pk`` in this module's directory. They are then copied into
    ``SPOTIPY_CLIENT_ID`` / ``SPOTIPY_CLIENT_SECRET``, and
    ``SPOTIPY_REDIRECT_URI`` is set to the Hugging Face Space URL.

    Returns:
        dict holding at least ``client_id`` and ``client_secret``.

    Raises:
        FileNotFoundError: if the env vars are missing and ``ids.pk`` is absent.
    """
    if 'client_id' in os.environ and 'client_secret' in os.environ:
        client_info = dict(client_id=os.environ['client_id'],
                           client_secret=os.environ['client_secret'])
    else:
        # SECURITY NOTE: pickle.load can execute arbitrary code. ids.pk must be
        # a trusted, locally provisioned file; never load it from user input.
        with open(os.path.join(dir_path, "ids.pk"), 'rb') as f:
            client_info = pickle.load(f)
    os.environ['SPOTIPY_CLIENT_ID'] = client_info['client_id']
    os.environ['SPOTIPY_CLIENT_SECRET'] = client_info['client_secret']
    os.environ['SPOTIPY_REDIRECT_URI'] = 'https://huggingface.co/spaces/ccolas/EmotionPlaylist/'
    return client_info
# Audio-feature names treated as relevant by this app
# (presumably keys of Spotify's audio-features objects — confirm at call sites).
relevant_audio_features = [
    "danceability",
    "energy",
    "loudness",
    "mode",
    "valence",
    "tempo",
]
def get_client():
    """Authorize through spotipy's prompt-based token helper.

    Requests the "playlist-modify-public" scope. ``prompt_for_user_token`` is
    called without explicit credentials, so it presumably relies on the
    SPOTIPY_* environment variables exported by setup_credentials — confirm.

    Returns:
        (sp, user_id): the authorized spotipy client and the current user's id.
    """
    token = sp_util.prompt_for_user_token(scope="playlist-modify-public")
    client = spotipy.Spotify(auth=token)
    return client, client.me()['id']
def add_button(url, text):
    """Render a centered HTML link styled as a button.

    ``url`` and ``text`` are interpolated into raw HTML written with
    ``unsafe_allow_html=True``, so both must come from trusted code.
    """
    html = f'''
<center>
<a style='color:black;' href="{url}">
<button class='css-1cpxqw2'>
{text}
</button>
</a></center>
'''
    st.write(html, unsafe_allow_html=True)
def new_get_client(session):
    """Run Spotify's OAuth authorization-code flow inside a Streamlit app.

    The token is cached in ``session`` via StreamlitCacheHandler.

    * No usable cached token and no ``code`` query parameter: a "Log in"
      button pointing at Spotify's authorization URL is rendered.
    * A ``code`` query parameter (redirect back from Spotify): the code is
      exchanged for a token.
    * Whenever a usable token exists after these steps, a client is built and
      the current user's id is fetched.

    Args:
        session: mutable mapping used to store the token (presumably
            ``st.session_state`` — confirm with the caller).

    Returns:
        (sp, user_id, auth_manager); ``sp`` and ``user_id`` are None while the
        user still has to log in.
    """
    scope = "playlist-modify-public"
    cache_handler = StreamlitCacheHandler(session)
    auth_manager = spotipy.oauth2.SpotifyOAuth(scope=scope,
                                               cache_handler=cache_handler,
                                               show_dialog=True)
    sp, user_id = None, None
    # experimental_get_query_params maps each key to a *list* of strings;
    # read it once and keep the first non-empty ``code`` as a plain string.
    code_values = st.experimental_get_query_params().get('code', [])
    auth_code = code_values[0] if code_values else None
    token_info = auth_manager.validate_token(cache_handler.get_cached_token())
    if not token_info and not auth_code:
        # Step 1. No usable token and not coming back from Spotify: show sign-in link.
        add_button(auth_manager.get_authorize_url(), 'Log in')
        return sp, user_id, auth_manager
    if auth_code:
        # Step 2. Being redirected from Spotify auth page: exchange the code.
        auth_manager.get_access_token(auth_code)
    # Either a valid token was cached or one was just obtained.
    sp = spotipy.Spotify(auth_manager=auth_manager)
    user_id = sp.me()['id']
    return sp, user_id, auth_manager
def extract_uris_from_links(links, url_type):
    """Extract Spotify ids from newline-separated links.

    For each non-blank line (surrounding whitespace removed), the text after
    the last ``"{url_type}/"`` is kept, e.g.
    ``https://open.spotify.com/playlist/<id>?si=...`` -> ``<id>``. Lines
    without that marker (raw ids, ``spotify:...`` URIs) are kept whole. Any
    ``?query`` suffix is dropped in both cases.

    Args:
        links: newline-separated URLs or ids.
        url_type: one of 'playlist', 'artist', 'user'.

    Returns:
        list of extracted ids, in input order.
    """
    assert url_type in ['playlist', 'artist', 'user']
    # Match on the requested type's path segment (the old check only looked
    # for 'playlist'/'user', so artist links were returned as full URLs).
    marker = f'{url_type}/'
    uris = []
    for url in links.split('\n'):
        url = url.strip()
        if not url:
            continue  # blank lines (e.g. a trailing newline) carry no id
        if marker in url:
            url = url.rpartition(marker)[2]
        uris.append(url.split('?')[0])
    return uris
def wall_of_checkboxes(labels, max_width=10):
    """Lay out one ``st.empty()`` placeholder per label in a grid.

    Placeholders are placed row by row, ``max_width`` per row, and the final
    row holds whatever remains. The placeholders are presumably filled with
    checkboxes by the caller.

    Args:
        labels: sized collection; only its length is used.
        max_width: number of columns per row.

    Returns:
        list of ``len(labels)`` placeholders, in label order.
    """
    n_labels = len(labels)
    n_rows = int(np.ceil(n_labels / max_width))
    checkboxes = []
    for i in range(n_rows):
        columns = st.columns(np.ones(max_width))
        # Remaining labels, capped at the row width. (n_labels % max_width is 0
        # for an exact multiple and would leave the last row empty.)
        row_length = min(max_width, n_labels - i * max_width)
        for j in range(row_length):
            with columns[j]:
                checkboxes.append(st.empty())
    return checkboxes
def find_legit_genre(glabel, legit_genres, verbose=False):
    """Map a free-form genre label onto one of ``legit_genres``.

    Labels are compared after removing spaces and hyphens. Rules, in order:

    * a label containing "jazz" maps to 'jazz'; one containing "ukpop" to 'pop';
    * an exact (normalized) match returns that legit genre immediately;
    * otherwise the best substring overlap wins: the label inside a legit
      genre scores the label's length, a legit genre inside the label scores
      that genre's length; ties keep the earlier genre.

    Args:
        glabel: genre label to classify.
        legit_genres: ordered list of accepted genre names.
        verbose: print the label and each candidate considered.

    Returns:
        The matched genre (an entry of ``legit_genres``, or 'jazz'/'pop' from
        the overrides), or "unknown" when nothing matches.
    """
    legit_genres_formatted = [lg.replace('-', '').replace(' ', '') for lg in legit_genres]
    glabel_formatted = glabel.replace(' ', '').replace('-', '')
    if verbose:
        print('\n', glabel)

    # Hard-coded overrides; they do not depend on legit_genres, so check them
    # once up front (they also apply when legit_genres is empty).
    if 'jazz' in glabel_formatted:
        if verbose:
            print('\t', 'jazz')
        return 'jazz'
    if 'ukpop' in glabel_formatted:
        if verbose:
            print('\t', 'pop')
        return 'pop'

    best_match = None
    best_match_score = 0
    for legit_glabel, legit_glabel_formatted in zip(legit_genres, legit_genres_formatted):
        if legit_glabel_formatted == glabel_formatted:
            if verbose:
                print('\t', legit_glabel_formatted)
            return legit_glabel
        if glabel_formatted in legit_glabel_formatted:
            score = len(glabel_formatted)
        elif legit_glabel_formatted in glabel_formatted:
            score = len(legit_glabel_formatted)
        else:
            continue
        if verbose:
            print('\t', legit_glabel_formatted)
        # Strictly greater: on ties the earlier legit genre is kept.
        if score > best_match_score:
            best_match, best_match_score = legit_glabel, score
    return "unknown" if best_match is None else best_match
# def aggregate_genres(genres, legit_genres, verbose=False):
# genres_output = dict()
# legit_genres_formatted = [lg.replace('-', '').replace(' ', '') for lg in legit_genres]
# for glabel in genres.keys():
# if verbose: print('\n', glabel)
# glabel_formatted = glabel.replace(' ', '').replace('-', '')
# best_match = None
# best_match_score = 0
# for legit_glabel, legit_glabel_formatted in zip(legit_genres, legit_genres_formatted):
# if 'jazz' in glabel_formatted:
# best_match = 'jazz'
# if verbose: print('\t', 'pop')
# break
# if 'ukpop' in glabel_formatted:
# best_match = 'pop'
# if verbose: print('\t', 'pop')
# break
# if legit_glabel_formatted == glabel_formatted:
# if verbose: print('\t', legit_glabel_formatted)
# best_match = legit_glabel
# break
# elif glabel_formatted in legit_glabel_formatted:
# if verbose: print('\t', legit_glabel_formatted)
# if len(glabel_formatted) > best_match_score:
# best_match = legit_glabel
# best_match_score = len(glabel_formatted)
# elif legit_glabel_formatted in glabel_formatted:
# if verbose: print('\t', legit_glabel_formatted)
# if len(legit_glabel_formatted) > best_match_score:
# best_match = legit_glabel
# best_match_score = len(legit_glabel_formatted)
#
# if best_match is not None:
# if verbose: print('\t', '-->', best_match)
# if best_match in genres_output.keys():
# genres_output[best_match] += genres[glabel]
# else:
# genres_output[best_match] = genres[glabel]
# else:
# if "unknown" in genres_output.keys():
# genres_output["unknown"] += genres[glabel]
# else:
# genres_output["unknown"] = genres[glabel]
# for k in genres_output.keys():
# genres_output[k] = sorted(set(genres_output[k]))
# return genres_output
def get_all_playlists_uris_from_users(sp, user_ids):
    """Collect the playlists of several Spotify users, without duplicates.

    For each user, pages through ``sp.user_playlists`` 50 entries at a time
    and keeps only the first occurrence of every playlist URI.

    Args:
        sp: spotipy client (anything exposing
            ``user_playlists(user_id, offset=..., limit=...)`` that returns a
            dict with an ``'items'`` list).
        user_ids: iterable of user ids.

    Returns:
        (all_uris, all_names): parallel lists; names are
        ``"<user_id>/<playlist name>"``.
    """
    page_size = 50
    all_uris = []
    all_names = []
    seen_uris = set()  # O(1) duplicate check; all_uris keeps insertion order
    for user_id in user_ids:
        print(user_id)
        offset = 0
        while True:
            playlist_list = sp.user_playlists(user_id, offset=offset, limit=page_size)
            items = playlist_list['items']
            for playlist in items:
                uri = playlist['uri']
                if uri not in seen_uris:
                    seen_uris.add(uri)
                    all_uris.append(uri)
                    all_names.append(user_id + '/' + playlist['name'])
            # A page shorter than page_size is the last one. (Comparing the
            # page length with ``offset`` stopped after the third page.)
            if len(items) < page_size:
                break
            offset += page_size
    return all_uris, all_names
class StreamlitCacheHandler(spotipy.cache_handler.CacheHandler):
    """Spotipy cache handler that keeps the OAuth token in a Streamlit session.

    The token is stored under the "token_info" key of the mapping passed to
    the constructor (presumably ``st.session_state`` — confirm with callers).
    """

    _TOKEN_KEY = "token_info"

    def __init__(self, session):
        self.session = session

    def get_cached_token(self):
        """Return the stored token, or None (printing a note) when it is absent."""
        try:
            return self.session[self._TOKEN_KEY]
        except KeyError:
            print("Token not found in the session")
        return None

    def save_token_to_cache(self, token_info):
        """Store ``token_info`` in the session; any failure is printed, not raised."""
        try:
            self.session[self._TOKEN_KEY] = token_info
        except Exception as err:
            print("Error saving token to cache: " + str(err))