Spaces:
Running
Running
AlekseyKorshuk
commited on
Commit
•
5739196
1
Parent(s):
df81345
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,9 @@ import random
|
|
4 |
import os
|
5 |
import streamlit as st
|
6 |
import lyricsgenius
|
|
|
|
|
|
|
7 |
|
8 |
|
9 |
st.set_page_config(page_title="HuggingArtists")
|
@@ -122,11 +125,140 @@ model_html = """
|
|
122 |
</div>
|
123 |
"""
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
if st.button("Run"):
|
126 |
-
|
|
|
127 |
artist = genius.search_artist(artist_name, max_songs=0, get_full_info=False)
|
128 |
if artist is not None:
|
129 |
artist_dict = genius.artist(artist.id)['artist']
|
130 |
artist_url = str(artist_dict['url'])
|
131 |
model_name = artist_url[artist_url.rfind('/') + 1:].lower()
|
132 |
st.markdown(model_html.replace("USER_PROFILE",artist.image_url).replace("USER_NAME",artist.name).replace("USER_HANDLE",model_name), unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import os
|
5 |
import streamlit as st
|
6 |
import lyricsgenius
|
7 |
+
import transformers
|
8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
9 |
+
|
10 |
|
11 |
|
12 |
st.set_page_config(page_title="HuggingArtists")
|
|
|
125 |
</div>
|
126 |
"""
|
127 |
|
128 |
+
|
129 |
+
def post_process(output_sequences):
|
130 |
+
predictions = []
|
131 |
+
generated_sequences = []
|
132 |
+
|
133 |
+
max_repeat = 2
|
134 |
+
|
135 |
+
# decode prediction
|
136 |
+
for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
|
137 |
+
generated_sequence = generated_sequence.tolist()
|
138 |
+
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True)
|
139 |
+
generated_sequences.append(text.strip())
|
140 |
+
|
141 |
+
for i, g in enumerate(generated_sequences):
|
142 |
+
res = str(g).replace('\n\n\n', '\n').replace('\n\n', '\n')
|
143 |
+
lines = res.split('\n')
|
144 |
+
# print(lines)
|
145 |
+
# i = max_repeat
|
146 |
+
# while i != len(lines):
|
147 |
+
# remove_count = 0
|
148 |
+
# for index in range(0, max_repeat):
|
149 |
+
# # print(i - index - 1, i - index)
|
150 |
+
# if lines[i - index - 1] == lines[i - index]:
|
151 |
+
# remove_count += 1
|
152 |
+
# if remove_count == max_repeat:
|
153 |
+
# lines.pop(i)
|
154 |
+
# i -= 1
|
155 |
+
# else:
|
156 |
+
# i += 1
|
157 |
+
predictions.append('\n'.join(lines))
|
158 |
+
|
159 |
+
return predictions
|
160 |
+
|
161 |
+
def get_table(table_data):
|
162 |
+
html = ("</head>\r\n"
|
163 |
+
"<body>\r\n\r\n"
|
164 |
+
"<h2></h2>"
|
165 |
+
"\r\n\r\n"
|
166 |
+
"<table>\r\n"
|
167 |
+
" <colgroup>\r\n"
|
168 |
+
" <col span=\"1"
|
169 |
+
"\" style=\"width: 10"
|
170 |
+
"%;\">\r\n"
|
171 |
+
" <col span=\"1"
|
172 |
+
"\" style=\"width: 10"
|
173 |
+
"0%;\">\r\n"
|
174 |
+
" </colgroup>\r\n"
|
175 |
+
f"{' '.join(table_data)}"
|
176 |
+
"</table>\r\n\r\n"
|
177 |
+
"</body>\r\n"
|
178 |
+
"</html>")
|
179 |
+
|
180 |
+
return html
|
181 |
+
|
182 |
+
def get_share_button(url):
|
183 |
+
return f'''
|
184 |
+
<div style="width: 76px;">
|
185 |
+
<a target="_blank" href="{url}" style='background-color:rgb(27, 149, 224);border-bottom-left-radius:4px;border-bottom-right-radius:4px;border-top-left-radius:4px;border-top-right-radius:4px;box-sizing:border-box;color:rgb(255, 255, 255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue", Arial, sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:28px;line-height:26px;outline-color:rgb(255, 255, 255);outline-style:none;outline-width:0px;padding-bottom:1px;padding-left:9px;padding-right:10px;padding-top:1px;position:relative;text-align:left;text-decoration-color:rgb(255, 255, 255);text-decoration-line:none;text-decoration-style:solid;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'>
|
186 |
+
<i style='background-attachment:scroll;background-clip:border-box;background-color:rgba(0,0,0,0);background-image:url(data:image/svg+xml,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20viewBox%3D%220%200%2072%2072%22%3E%3Cpath%20fill%3D%22none%22%20d%3D%22M0%200h72v72H0z%22%2F%3E%3Cpath%20class%3D%22icon%22%20fill%3D%22%23fff%22%20d%3D%22M68.812%2015.14c-2.348%201.04-4.87%201.744-7.52%202.06%202.704-1.62%204.78-4.186%205.757-7.243-2.53%201.5-5.33%202.592-8.314%203.176C56.35%2010.59%2052.948%209%2049.182%209c-7.23%200-13.092%205.86-13.092%2013.093%200%201.026.118%202.02.338%202.98C25.543%2024.527%2015.9%2019.318%209.44%2011.396c-1.125%201.936-1.77%204.184-1.77%206.58%200%204.543%202.312%208.552%205.824%2010.9-2.146-.07-4.165-.658-5.93-1.64-.002.056-.002.11-.002.163%200%206.345%204.513%2011.638%2010.504%2012.84-1.1.298-2.256.457-3.45.457-.845%200-1.666-.078-2.464-.23%201.667%205.2%206.5%208.985%2012.23%209.09-4.482%203.51-10.13%205.605-16.26%205.605-1.055%200-2.096-.06-3.122-.184%205.794%203.717%2012.676%205.882%2020.067%205.882%2024.083%200%2037.25-19.95%2037.25-37.25%200-.565-.013-1.133-.038-1.693%202.558-1.847%204.778-4.15%206.532-6.774z%22%2F%3E%3C%2Fsvg%3E);background-origin:padding-box;background-position-x:0px;background-position-y:0px;background-repeat-x;background-repeat-y;background-size:auto;color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:italic;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:18px;line-height:26px;position:relative;text-align:left;text-decoration-thickness:auto;top:4px;user-select:none;white-space:nowrap;width:18px;'></i>
|
187 |
+
<span style='color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;line-height:26px;margin-left:4px;text-align:left;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'>Tweet</span>
|
188 |
+
</a>
|
189 |
+
</div>
|
190 |
+
'''
|
191 |
+
|
192 |
+
def share_model_table(artist_name, model_name):
|
193 |
+
url = f"https://twitter.com/intent/tweet?text=I created an AI bot of {artist_name} with %23huggingartists!%0APlay with my model or create your own! &url=https://huggingface.co/huggingartists/{model_name}"
|
194 |
+
|
195 |
+
share_button = get_share_button(url)
|
196 |
+
table_data = [
|
197 |
+
f'<tr><td>{share_button}</td><td>🎉 Share {artist_name} model: <a href="https://huggingface.co/huggingartists/{model_name}">https://huggingface.co/huggingartists/{model_name}</a></td></tr>'
|
198 |
+
]
|
199 |
+
return get_table(table_data)
|
200 |
+
|
201 |
+
def get_share_lyrics_url(artist_name, model_name, lyrics):
|
202 |
+
return "https://twitter.com/intent/tweet?text=I created an AI bot of " + artist_name + " with %23huggingartists!%0A%0ABrand new song:%0A" + lyrics.replace('\n', '%0A').replace('"', '%22') + "%0A%0APlay with my model or create your own! &url=https://huggingface.co/huggingartists/" + model_name
|
203 |
+
|
204 |
if st.button("Run"):
|
205 |
+
model_name = None
|
206 |
+
with st.spinner(text=f"Searching for {artist_name } in Genius..."):
|
207 |
artist = genius.search_artist(artist_name, max_songs=0, get_full_info=False)
|
208 |
if artist is not None:
|
209 |
artist_dict = genius.artist(artist.id)['artist']
|
210 |
artist_url = str(artist_dict['url'])
|
211 |
model_name = artist_url[artist_url.rfind('/') + 1:].lower()
|
212 |
st.markdown(model_html.replace("USER_PROFILE",artist.image_url).replace("USER_NAME",artist.name).replace("USER_HANDLE",model_name), unsafe_allow_html=True)
|
213 |
+
else:
|
214 |
+
st.markdown(f"Could not find {artist_name}! Be sure that he/she exists in [Genius](https://genius.com/).")
|
215 |
+
if model_name is not None:
|
216 |
+
with st.spinner(text=f"Downloading the model of {artist_name }..."):
|
217 |
+
model = None
|
218 |
+
tokenizer = None
|
219 |
+
try:
|
220 |
+
tokenizer = AutoTokenizer.from_pretrained(f"huggingartists/{model_name}")
|
221 |
+
model = AutoModelForCausalLM.from_pretrained(f"huggingartists/{model_name}")
|
222 |
+
except:
|
223 |
+
st.markdown(f"Model for this artist does not exist yet. Create it in just 5 min with [Colab Notebook](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb):")
|
224 |
+
st.markdown(
|
225 |
+
"""
|
226 |
+
<style>
|
227 |
+
.aligncenter {
|
228 |
+
text-align: center;
|
229 |
+
}
|
230 |
+
</style>
|
231 |
+
<p class="aligncenter">
|
232 |
+
<a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank">
|
233 |
+
<img src="https://colab.research.google.com/assets/colab-badge.svg"/>
|
234 |
+
</a>
|
235 |
+
</p>
|
236 |
+
""",
|
237 |
+
unsafe_allow_html=True,
|
238 |
+
)
|
239 |
+
if model is not None:
|
240 |
+
with st.spinner(text=f"Generating lyrics..."):
|
241 |
+
encoded_prompt = tokenizer(start, add_special_tokens=False, return_tensors="pt").input_ids
|
242 |
+
encoded_prompt = encoded_prompt.to(trainer.model.device)
|
243 |
+
# prediction
|
244 |
+
output_sequences = trainer.model.generate(
|
245 |
+
input_ids=encoded_prompt,
|
246 |
+
max_length=max_length,
|
247 |
+
min_length=min_length,
|
248 |
+
temperature=float(temperature),
|
249 |
+
top_p=float(top_p),
|
250 |
+
top_k=int(top_k),
|
251 |
+
do_sample=True,
|
252 |
+
repetition_penalty=1.0,
|
253 |
+
num_return_sequences=num_sequences
|
254 |
+
)
|
255 |
+
# Post-processing
|
256 |
+
predictions = post_process(output_sequences)
|
257 |
+
table_data = []
|
258 |
+
for result in predictions:
|
259 |
+
table_data.append('<tr><td>' + get_share_button(get_share_lyrics_url(artist.name, model_name, result)) + '</td><td>' + result.replace("\n", "<br>") + '</td></tr>')
|
260 |
+
st.markdown(share_model_table(artist.name, model_name),
|
261 |
+
unsafe_allow_html=True)
|
262 |
+
st.markdown(get_table(table_data),
|
263 |
+
unsafe_allow_html=True)
|
264 |
+
|