# HuggingArtists demo — Streamlit app (deployed as a Hugging Face Space).
import json
import math
import os
import random
from urllib.parse import quote

import lyricsgenius
import streamlit as st
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
st.set_page_config(page_title="HuggingArtists") | |
st.title("HuggingArtists") | |
st.sidebar.markdown( | |
""" | |
<style> | |
.aligncenter { | |
text-align: center; | |
} | |
</style> | |
<p class="aligncenter"> | |
<img src="https://raw.githubusercontent.com/AlekseyKorshuk/huggingartists/master/img/logo.jpg" width="420" /> | |
</p> | |
""", | |
unsafe_allow_html=True, | |
) | |
st.sidebar.markdown( | |
""" | |
<style> | |
.aligncenter { | |
text-align: center; | |
} | |
</style> | |
<p style='text-align: center'> | |
<a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank">GitHub</a> | <a href="https://wandb.ai/huggingartists/huggingartists/reportlist" target="_blank">Project Report</a> | |
</p> | |
<p class="aligncenter"> | |
<a href="https://github.com/AlekseyKorshuk/huggingartists" target="_blank"> | |
<img src="https://img.shields.io/github/stars/AlekseyKorshuk/huggingartists?style=social"/> | |
</a> | |
</p> | |
<p class="aligncenter"> | |
<a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank"> | |
<img src="https://colab.research.google.com/assets/colab-badge.svg"/> | |
</a> | |
</p> | |
""", | |
unsafe_allow_html=True, | |
) | |
st.sidebar.header("SETTINGS") | |
num_sequences = st.sidebar.number_input( | |
"Number of sequences to generate", | |
min_value=1, | |
value=5, | |
help="The amount of generated texts", | |
) | |
min_length = st.sidebar.number_input( | |
"Minimum length of the sequence", | |
min_value=1, | |
value=100, | |
help="The minimum length of the sequence to be generated", | |
) | |
max_length= st.sidebar.number_input( | |
"Maximum length of the sequence", | |
min_value=1, | |
value=160, | |
help="The maximum length of the sequence to be generated", | |
) | |
temperature = st.sidebar.slider( | |
"Temperature", | |
min_value=0.0, | |
max_value=3.0, | |
step=0.01, | |
value=1.0, | |
help="The value used to module the next token probabilities", | |
) | |
top_p = st.sidebar.slider( | |
"Top-P", | |
min_value=0.0, | |
max_value=1.0, | |
step=0.01, | |
value=0.95, | |
help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.", | |
) | |
top_k= st.sidebar.number_input( | |
"Top-K", | |
min_value=0, | |
value=50, | |
step=1, | |
help="The number of highest probability vocabulary tokens to keep for top-k-filtering.", | |
) | |
caption = ( | |
"In HuggingArtists, we can generate lyrics by a specific artist. This was made by fine-tuning a pre-trained HuggingFace Transformer on parsed datasets from Genius." | |
) | |
st.markdown("`HuggingArtists` - Train a model to generate lyrics π΅") | |
st.markdown(caption) | |
artist_name = st.text_input("Artist name:", "Eminem") | |
start = st.text_input("Beginning of the song:", "But for me to rap like a computer") | |
TOKEN = "q_JK_BFy9OMiG7fGTzL-nUto9JDv3iXI24aYRrQnkOvjSCSbY4BuFIindweRsr5I" | |
genius = lyricsgenius.Genius(TOKEN) | |
model_html = """ | |
<div class="inline-flex flex-col" style="line-height: 1.5;"> | |
<div class="flex"> | |
<div | |
\t\t\tstyle="display:DISPLAY_1; margin-left: auto; margin-right: auto; width: 92px; height:92px; border-radius: 50%; background-size: cover; background-image: url('USER_PROFILE')"> | |
</div> | |
</div> | |
<div style="text-align: center; margin-top: 3px; font-size: 16px; font-weight: 800">π€ HuggingArtists Model π€</div> | |
<div style="text-align: center; font-size: 16px; font-weight: 800">USER_NAME</div> | |
<a href="https://genius.com/artists/USER_HANDLE"> | |
\t<div style="text-align: center; font-size: 14px;">@USER_HANDLE</div> | |
</a> | |
</div> | |
""" | |
def post_process(output_sequences): | |
predictions = [] | |
generated_sequences = [] | |
max_repeat = 2 | |
# decode prediction | |
for generated_sequence_idx, generated_sequence in enumerate(output_sequences): | |
generated_sequence = generated_sequence.tolist() | |
text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True, skip_special_tokens=True) | |
generated_sequences.append(text.strip()) | |
for i, g in enumerate(generated_sequences): | |
res = str(g).replace('\n\n\n', '\n').replace('\n\n', '\n') | |
lines = res.split('\n') | |
# print(lines) | |
# i = max_repeat | |
# while i != len(lines): | |
# remove_count = 0 | |
# for index in range(0, max_repeat): | |
# # print(i - index - 1, i - index) | |
# if lines[i - index - 1] == lines[i - index]: | |
# remove_count += 1 | |
# if remove_count == max_repeat: | |
# lines.pop(i) | |
# i -= 1 | |
# else: | |
# i += 1 | |
predictions.append('\n'.join(lines)) | |
return predictions | |
def get_table(table_data): | |
html = ("</head>\r\n" | |
"<body>\r\n\r\n" | |
"<h2></h2>" | |
"\r\n\r\n" | |
"<table>\r\n" | |
" <colgroup>\r\n" | |
" <col span=\"1" | |
"\" style=\"width: 10" | |
"%;\">\r\n" | |
" <col span=\"1" | |
"\" style=\"width: 10" | |
"0%;\">\r\n" | |
" </colgroup>\r\n" | |
f"{' '.join(table_data)}" | |
"</table>\r\n\r\n" | |
"</body>\r\n" | |
"</html>") | |
return html | |
def get_share_button(url): | |
return f''' | |
<div style="width: 76px;"> | |
<a target="_blank" href="{url}" style='background-color:rgb(27, 149, 224);border-bottom-left-radius:4px;border-bottom-right-radius:4px;border-top-left-radius:4px;border-top-right-radius:4px;box-sizing:border-box;color:rgb(255, 255, 255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue", Arial, sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:28px;line-height:26px;outline-color:rgb(255, 255, 255);outline-style:none;outline-width:0px;padding-bottom:1px;padding-left:9px;padding-right:10px;padding-top:1px;position:relative;text-align:left;text-decoration-color:rgb(255, 255, 255);text-decoration-line:none;text-decoration-style:solid;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'> | |
<i style='background-attachment:scroll;background-clip:border-box;background-color:rgba(0,0,0,0);background-image:url(data:image/svg+xml,%3Csvg%20xmlns%3D%22http%3A%2F%2Fwww.w3.org%2F2000%2Fsvg%22%20viewBox%3D%220%200%2072%2072%22%3E%3Cpath%20fill%3D%22none%22%20d%3D%22M0%200h72v72H0z%22%2F%3E%3Cpath%20class%3D%22icon%22%20fill%3D%22%23fff%22%20d%3D%22M68.812%2015.14c-2.348%201.04-4.87%201.744-7.52%202.06%202.704-1.62%204.78-4.186%205.757-7.243-2.53%201.5-5.33%202.592-8.314%203.176C56.35%2010.59%2052.948%209%2049.182%209c-7.23%200-13.092%205.86-13.092%2013.093%200%201.026.118%202.02.338%202.98C25.543%2024.527%2015.9%2019.318%209.44%2011.396c-1.125%201.936-1.77%204.184-1.77%206.58%200%204.543%202.312%208.552%205.824%2010.9-2.146-.07-4.165-.658-5.93-1.64-.002.056-.002.11-.002.163%200%206.345%204.513%2011.638%2010.504%2012.84-1.1.298-2.256.457-3.45.457-.845%200-1.666-.078-2.464-.23%201.667%205.2%206.5%208.985%2012.23%209.09-4.482%203.51-10.13%205.605-16.26%205.605-1.055%200-2.096-.06-3.122-.184%205.794%203.717%2012.676%205.882%2020.067%205.882%2024.083%200%2037.25-19.95%2037.25-37.25%200-.565-.013-1.133-.038-1.693%202.558-1.847%204.778-4.15%206.532-6.774z%22%2F%3E%3C%2Fsvg%3E);background-origin:padding-box;background-position-x:0px;background-position-y:0px;background-repeat-x;background-repeat-y;background-size:auto;color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:italic;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;height:18px;line-height:26px;position:relative;text-align:left;text-decoration-thickness:auto;top:4px;user-select:none;white-space:nowrap;width:18px;'></i> | |
<span style='color:rgb(255,255,255);cursor:pointer;display:inline-block;font-family:"Helvetica Neue",Arial,sans-serif;font-size:13px;font-stretch:100%;font-style:normal;font-variant-caps:normal;font-variant-east-asian:normal;font-variant-ligatures:normal;font-variant-numeric:normal;font-weight:500;line-height:26px;margin-left:4px;text-align:left;text-decoration-thickness:auto;user-select:none;vertical-align:top;white-space:nowrap;zoom:1;'>Tweet</span> | |
</a> | |
</div> | |
''' | |
def share_model_table(artist_name, model_name): | |
url = f"https://twitter.com/intent/tweet?text=I created an AI bot of {artist_name} with %23huggingartists!%0APlay with my model or create your own! &url=https://huggingface.co/huggingartists/{model_name}" | |
share_button = get_share_button(url) | |
table_data = [ | |
f'<tr><td>{share_button}</td><td>π Share {artist_name} model: <a href="https://huggingface.co/huggingartists/{model_name}">https://huggingface.co/huggingartists/{model_name}</a></td></tr>' | |
] | |
return get_table(table_data) | |
def get_share_lyrics_url(artist_name, model_name, lyrics): | |
return "https://twitter.com/intent/tweet?text=I created an AI bot of " + artist_name + " with %23huggingartists!%0A%0ABrand new song:%0A" + lyrics.replace('\n', '%0A').replace('"', '%22') + "%0A%0APlay with my model or create your own! &url=https://huggingface.co/huggingartists/" + model_name | |
if st.button("Run"): | |
model_name = None | |
with st.spinner(text=f"Searching for {artist_name } in Genius..."): | |
artist = genius.search_artist(artist_name, max_songs=0, get_full_info=False) | |
if artist is not None: | |
artist_dict = genius.artist(artist.id)['artist'] | |
artist_url = str(artist_dict['url']) | |
model_name = artist_url[artist_url.rfind('/') + 1:].lower() | |
st.markdown(model_html.replace("USER_PROFILE",artist.image_url).replace("USER_NAME",artist.name).replace("USER_HANDLE",model_name), unsafe_allow_html=True) | |
else: | |
st.markdown(f"Could not find {artist_name}! Be sure that he/she exists in [Genius](https://genius.com/).") | |
if model_name is not None: | |
with st.spinner(text=f"Downloading the model of {artist_name }..."): | |
model = None | |
tokenizer = None | |
try: | |
tokenizer = AutoTokenizer.from_pretrained(f"huggingartists/{model_name}") | |
model = AutoModelForCausalLM.from_pretrained(f"huggingartists/{model_name}") | |
except Exception as ex: | |
st.markdown(ex) | |
st.markdown(f"Model for this artist does not exist yet. Create it in just 5 min with [Colab Notebook](https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb):") | |
st.markdown( | |
""" | |
<style> | |
.aligncenter { | |
text-align: center; | |
} | |
</style> | |
<p class="aligncenter"> | |
<a href="https://colab.research.google.com/github/AlekseyKorshuk/huggingartists/blob/master/huggingartists-demo.ipynb" target="_blank"> | |
<img src="https://colab.research.google.com/assets/colab-badge.svg"/> | |
</a> | |
</p> | |
""", | |
unsafe_allow_html=True, | |
) | |
if model is not None: | |
with st.spinner(text=f"Generating lyrics..."): | |
encoded_prompt = tokenizer(start, add_special_tokens=False, return_tensors="pt").input_ids | |
encoded_prompt = encoded_prompt.to(model.device) | |
# prediction | |
output_sequences = model.generate( | |
input_ids=encoded_prompt, | |
max_length=max_length, | |
min_length=min_length, | |
temperature=float(temperature), | |
top_p=float(top_p), | |
top_k=int(top_k), | |
do_sample=True, | |
repetition_penalty=1.0, | |
num_return_sequences=num_sequences | |
) | |
# Post-processing | |
predictions = post_process(output_sequences) | |
st.subheader("Result") | |
for prediction in predictions: | |
st.text(prediction) | |