damilojohn's picture
Update app.py
b1deadc
raw
history blame contribute delete
No virus
5.19 kB
# importing libraries
import gradio as gr
import pickle
import pandas as pd
from huggingface_hub import hf_hub_download
import json
import requests
pd.options.mode.chained_assignment = None # Turn off SettingWithCopyWarning
songs_df = pd.read_csv(
hf_hub_download('damilojohn/Personal_Playlist_Generator',
repo_type='dataset',
filename='spotify_transformed.csv'))
mappings = pd.read_csv(
hf_hub_download('damilojohn/Personal_Playlist_Generator',
repo_type='dataset',
filename='song_mappings.csv'))
verses_df = pd.read_csv(
hf_hub_download('damilojohn/Personal_Playlist_Generator',
repo_type='dataset',
filename='verses.csv'))
song_embeddings = pickle.load(
open(hf_hub_download('damilojohn/Personal_Playlist_Generator',
repo_type='dataset',
filename='embeddings.pkl'), 'rb'))
verses_df.rename(columns={'0': 'verse'}, inplace=True)
mappings.rename(columns={'Unnamed: 0': 'verse', '0': 'song_name'},
inplace=True)
def generate_playlist(prompt):
payload = {'prompt': prompt}
response = requests.request('POST',
url='https://xi5j0hwh1a.execute-api.eu-west-2.amazonaws.com/test/huh',
data=json.dumps(payload)).json()
hits = response['hits']
hits = pd.DataFrame.from_dict(hits[0])
verses_match = verses_df.iloc[hits['corpus_id']]
songs_match = mappings[mappings['verse'].isin(
verses_match['verse'].values)]
songs_match = songs_df[songs_df['song_name'].isin(
songs_match['song_name'].values)]
songs_match = songs_match.sort_values('song_name')
songs_match = songs_match.drop_duplicates(subset='song_name')
songs_name = list(songs_match['song_name'][:9])
cover_art = list(songs_match['image'][:9])
images = [gr.Image.update(value=art, visible=True) for art in cover_art]
return (gr.Radio.update(label='songs', interactive=True,
choices=songs_name),
*images)
def set_example_prompt(examples):
return gr.TextArea.update(value=examples[0])
def create_frontend():
demo = gr.Blocks()
with demo:
gr.Markdown(
'''
# A Text based playlist Generator for Afrobeats
'''
)
with gr.Row():
with gr.Column():
gr.Markdown(
'''
Enter words describing your playlist
'''
)
song_prompt = gr.TextArea(
value='',
placeholder=" Enter a sentence that describes how you're feeling or what you want your playlist to be about "
)
example_prompts = gr.Dataset(
components=[song_prompt],
samples=[
['heartbreak'],
['love at the beach'],
['uncertainty and bleak hopes']
]
)
with gr.Column():
fetch_songs = gr.Button(
value='Enter to see playlist',).style(full_width=True)
with gr.Column():
song_options = gr.Radio(label='songs', interactive=True,
choices=None, type='value',
visible=True)
with gr.Column():
with gr.Row():
tile1 = gr.Image(value="songs_cover.jpg",
show_label=False, visible=True)
tile2 = gr.Image(value="songs_cover.jpg",
show_label=False, visible=True)
tile3 = gr.Image(value="songs_cover.jpg",
show_label=False, visible=True)
with gr.Row():
tile4 = gr.Image(value="songs_cover.jpg",
show_label=False, visible=True)
tile5 = gr.Image(value="songs_cover.jpg",
show_label=False, visible=True)
tile6 = gr.Image(value="songs_cover.jpg",
show_label=False, visible=True)
with gr.Row():
tile7 = gr.Image(value="songs_cover.jpg",
show_label=False, visible=True)
tile8 = gr.Image(value='songs_cover.jpg',
show_label=False, visible=True)
tiles = [tile1, tile2, tile3, tile4, tile5, tile6, tile7, tile8]
fetch_songs.click(
fn=generate_playlist,
inputs=[song_prompt],
outputs=[song_options, *tiles]
)
example_prompts.click(
fn=set_example_prompt,
inputs=[example_prompts],
outputs=example_prompts.components
)
demo.launch(debug=True)
def main():
create_frontend()
if __name__ == "__main__":
main()