NimaBoscarino commited on
Commit
043d857
1 Parent(s): 2570d24

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer, util
2
+ from huggingface_hub import hf_hub_download
3
+ import pickle
4
+ import pandas as pd
5
+ import gradio as gr
6
+
7
+ pd.options.mode.chained_assignment = None # Turn off SettingWithCopyWarning
8
+
9
+ pickled = pickle.load(open(hf_hub_download("NimaBoscarino/playlist-generator", filename="clean-large_embeddings_msmarco-MiniLM-L-6-v3.pkl"), "rb"))
10
+ songs = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator", filename="songs_new.csv"))
11
+ verses = pickle.load(open(hf_hub_download("NimaBoscarino/playlist-generator", filename="verses.pkl"), "rb"))
12
+ lyrics = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator", filename="lyrics_new.csv"))
13
+
14
+ embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
15
+
16
+ genius_ids = pickled["genius_ids"]
17
+ corpus_embeddings = pickled["embeddings"]
18
+
19
+
20
+ def generate_playlist(prompt):
21
+ prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
22
+ hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
23
+ hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])
24
+
25
+ verse_match = verses.iloc[hits['corpus_id']]
26
+ verse_match = verse_match.drop_duplicates(subset=["genius_id"])
27
+ song_match = songs[songs["genius_id"].isin(verse_match["genius_id"].values)]
28
+ song_match.genius_id = pd.Categorical(song_match.genius_id, categories=verse_match["genius_id"].values)
29
+ song_match = song_match.sort_values("genius_id")
30
+ song_match = song_match[0:9] # Only grab the top 9
31
+
32
+ song_names = list(song_match["full_title"])
33
+ song_art = list(song_match["art"].fillna("https://i.imgur.com/bgCDfT1.jpg"))
34
+ images = [gr.Image.update(value=art, visible=True) for art in song_art]
35
+
36
+ return (
37
+ gr.Radio.update(label="Songs", interactive=True, choices=song_names),
38
+ *images
39
+ )
40
+
41
+
42
+ def set_lyrics(full_title):
43
+ lyrics_text = lyrics[lyrics["genius_id"].isin(songs[songs["full_title"] == full_title]["genius_id"])]["text"].iloc[0]
44
+ return gr.Textbox.update(value=lyrics_text)
45
+
46
+
47
+ def set_example_prompt(example):
48
+ return gr.TextArea.update(value=example[0])
49
+
50
+
51
+ demo = gr.Blocks()
52
+
53
+ with demo:
54
+ gr.Markdown(
55
+ """
56
+ # Playlist Generator 📻 🎵
57
+ """)
58
+
59
+ with gr.Row():
60
+ with gr.Column():
61
+ gr.Markdown(
62
+ """
63
+ Enter a prompt and generate a playlist based on ✨semantic similarity✨
64
+ This was built using Sentence Transformers and Gradio – [read more here!](#)
65
+ """)
66
+
67
+ song_prompt = gr.TextArea(
68
+ value="Running wild and free",
69
+ placeholder="Enter a song prompt, or choose an example"
70
+ )
71
+ example_prompts = gr.Dataset(
72
+ components=[song_prompt],
73
+ samples=[
74
+ ["I feel nostalgic for the past"],
75
+ ["Running wild and free"],
76
+ ["I'm deeply in love with someone I just met!"],
77
+ ["My friends mean the world to me"],
78
+ ["Sometimes I feel like no one understands"],
79
+ ]
80
+ )
81
+
82
+ with gr.Column():
83
+ fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽‍🎤").style(full_width=True)
84
+
85
+ with gr.Row():
86
+ tile1 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
87
+ tile2 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
88
+ tile3 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
89
+ with gr.Row():
90
+ tile4 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
91
+ tile5 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
92
+ tile6 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
93
+ with gr.Row():
94
+ tile7 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
95
+ tile8 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
96
+ tile9 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
97
+
98
+ # Workaround because of the Gallery issues
99
+ tiles = [tile1, tile2, tile3, tile4, tile5, tile6, tile7, tile8, tile9]
100
+
101
+ song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")
102
+
103
+ with gr.Column():
104
+ verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")
105
+
106
+ fetch_songs.click(
107
+ fn=generate_playlist,
108
+ inputs=[song_prompt],
109
+ outputs=[song_option, *tiles],
110
+ )
111
+
112
+ example_prompts.click(
113
+ fn=set_example_prompt,
114
+ inputs=example_prompts,
115
+ outputs=example_prompts.components,
116
+ )
117
+
118
+ song_option.change(
119
+ fn=set_lyrics,
120
+ inputs=[song_option],
121
+ outputs=[verse]
122
+ )
123
+
124
+ demo.launch()