# Source: Hugging Face Spaces page (Space status at capture time: Paused)
import pandas as pd
import streamlit as st

from ArtistCoherencyModel import ArtistCoherencyModel
from LyricGeneratorModel import LyricGeneratorModel
def get_artists(path="artists.csv"):
    """Return the artist names listed in the ``name`` column of a CSV file.

    Args:
        path: Path (or file-like object) of the CSV to read. Defaults to
            ``"artists.csv"``, the file this app ships with.

    Returns:
        A list of the values in the ``name`` column, in file order.
    """
    return pd.read_csv(path)["name"].tolist()
@st.cache_resource(show_spinner="Loading Evaluation Model...")
def _load_evaluator_model():
    """Load the artist-coherency ensemble once and reuse it on later reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the model would be reloaded each time. The spinner text
    is shown only while the load actually runs (on a cache miss).

    NOTE(review): ``st.cache_resource`` shares the object across all user
    sessions; confirm ArtistCoherencyModel is safe for concurrent use.
    """
    return ArtistCoherencyModel.from_pretrained("tjl223/artist-coherency-ensemble")


def get_evaluator_model():
    """Return the lyric evaluator model, loading it on first use.

    Displays a success message once the model is available.
    """
    lyric_evaluator_model = _load_evaluator_model()
    st.success("Finished Loading Evaluation Model")
    return lyric_evaluator_model
@st.cache_resource(show_spinner="Loading Generator Model...")
def _load_generator_model():
    """Construct the lyric generator once and reuse it on later reruns.

    Streamlit re-executes the whole script on every widget interaction;
    without caching, the (large) generator model would be rebuilt each time.
    The spinner text is shown only while construction actually runs.

    NOTE(review): ``st.cache_resource`` shares the object across all user
    sessions; confirm LyricGeneratorModel is safe for concurrent use.
    """
    return LyricGeneratorModel("tjl223/llama2-qlora-lyric-generator")


def get_generator_model():
    """Return the lyric generator model, loading it on first use.

    Displays a success message once the model is available.
    """
    lyric_generator_model = _load_generator_model()
    st.success("Finished Loading Generator Model")
    return lyric_generator_model
# Obtain both models (evaluator first, then generator -- the tuple is
# evaluated left to right) and the artist names offered by the selectboxes.
lyric_evaluator_model, lyric_generator_model = (
    get_evaluator_model(),
    get_generator_model(),
)
artist_names_list = get_artists()
# --- Lyric Generator: build a tagged prompt from the user's inputs, run the
# generator model, and render its output line by line.
generator_title = st.title("Lyric Generator")
artist_name_for_generator = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")
if st.button("Submit"):
    prompt = (
        f"[Song Title] {song_title}\n"
        f"[Song Artist] {artist_name_for_generator}\n"
        f"[Song Description] {song_description}"
    )
    print(f"Prompt: {prompt}")
    with st.spinner("Generating Lyrics..."):
        # The meaning of 1000 is defined by LyricGeneratorModel.generate_lyrics
        # (presumably a length limit -- confirm against that class).
        lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.success("Finished Generating Lyrics")
    print(f"Lyrics: {lyrics}")
    for raw_line in lyrics.split("\n"):
        # Strip before inspecting/rendering: "**text **" does not render as
        # bold in markdown, and a line with 4+ leading spaces would be shown
        # as a code block.
        line = raw_line.strip()
        if line == "<END_OF_SONG>":
            # End marker in the model output; ignore anything after it.
            break
        if line.startswith("["):
            # Bracketed lines (presumably section tags) are shown in bold.
            st.markdown(f"**{line}**")
        else:
            st.write(line)
# --- Lyric Evaluator: score how well user-supplied lyrics fit a chosen artist.
evaluator_title = st.title("Lyric Evaluator")
# NOTE: these widget keys mention "generator" but belong to the evaluator;
# they are left unchanged so the widgets keep the same identity/state keys.
artist_name_for_evaluator = st.selectbox(
    "Artist", artist_names_list, key="genorator_select"
)
evaluator_song_lyrics = st.text_area("Song Lyrics")
if st.button("Submit", key="generator_submit"):
    # Remove blank lines entirely. A single str.replace("\n\n", "\n") pass
    # leaves a blank line behind for runs of three or more newlines.
    normalized_lyrics = "\n".join(
        line for line in evaluator_song_lyrics.split("\n") if line.strip()
    )
    if not normalized_lyrics:
        st.warning("Please enter song lyrics to evaluate.")
    else:
        score = lyric_evaluator_model.generate_artist_coherency_score(
            artist_name_for_evaluator, normalized_lyrics
        )
        print(f"Score: {score}")
        st.write(f"Score: {score}")
# --- Lyric Predictor: predict the artist and a coherency label for lyrics.
predictor_title = st.title("Lyric Predictor")
predictor_song_lyrics = st.text_area("Song Lyrics", key="predictor song lyrics")
if st.button("Submit", key="predictor_submit"):
    # Normalize once and reuse for both predictions. Blank lines are removed
    # entirely; a single str.replace("\n\n", "\n") pass leaves a blank line
    # behind for runs of three or more newlines.
    normalized_lyrics = "\n".join(
        line for line in predictor_song_lyrics.split("\n") if line.strip()
    )
    if not normalized_lyrics:
        st.warning("Please enter song lyrics to analyze.")
    else:
        artist, artist_score = lyric_evaluator_model.predict_artist(
            normalized_lyrics
        )
        coherency, coherency_score = lyric_evaluator_model.predict_coherency(
            normalized_lyrics
        )
        print(f"Predicted {artist} with a score of {artist_score}")
        st.write(f"Predicted {artist} with a score of {artist_score}")
        print(f"Predicted {coherency} with a score of {coherency_score}")
        st.write(f"Predicted {coherency} with a score of {coherency_score}")