Spaces:
Paused
Paused
| from ArtistCoherencyModel import ArtistCoherencyModel | |
| import streamlit as st | |
| import pandas as pd | |
| from LyricGeneratorModel import LyricGeneratorModel | |
def get_artists():
    """Load the selectable artist names.

    Reads ``artists.csv`` (resolved relative to the current working
    directory) and returns the values of its ``name`` column, in file
    order, as a plain list.
    """
    csv_path = "artists.csv"
    name_series = pd.read_csv(csv_path)["name"]
    return [*name_series]
@st.cache_resource(show_spinner=False)
def get_evaluator_model():
    """Load the artist-coherency evaluation model, cached across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be reloaded on every click/keystroke.
    ``st.cache_resource`` keeps one shared instance instead. Its built-in
    spinner is disabled because an explicit, labelled spinner is shown here.

    NOTE(review): the cached instance is shared by all sessions; this assumes
    the model's scoring/prediction methods are safe to call concurrently —
    confirm.

    Returns:
        The ``ArtistCoherencyModel`` loaded from the
        ``tjl223/artist-coherency-ensemble`` checkpoint.
    """
    with st.spinner("Loading Evaluation Model..."):
        model = ArtistCoherencyModel.from_pretrained(
            "tjl223/artist-coherency-ensemble"
        )
    st.success("Finished Loading Evaluation Model")
    return model
@st.cache_resource(show_spinner=False)
def get_generator_model():
    """Load the lyric generator model, cached across reruns.

    Streamlit re-executes this whole script on every widget interaction;
    without caching, the generator would be re-instantiated on every
    click/keystroke. ``st.cache_resource`` keeps a single shared instance.
    Its built-in spinner is disabled in favour of the labelled one below.

    NOTE(review): the cached instance is shared by all sessions; this assumes
    ``generate_lyrics`` is safe to call concurrently — confirm.

    Returns:
        The ``LyricGeneratorModel`` built from the
        ``tjl223/llama2-qlora-lyric-generator`` checkpoint.
    """
    with st.spinner("Loading Generator Model..."):
        model = LyricGeneratorModel("tjl223/llama2-qlora-lyric-generator")
    st.success("Finished Loading Generator Model")
    return model
# Load both models and the artist list up front (evaluated left to right:
# evaluator, then generator, then artists); the UI sections below use them.
(
    lyric_evaluator_model,
    lyric_generator_model,
    artist_names_list,
) = (
    get_evaluator_model(),
    get_generator_model(),
    get_artists(),
)
# --- Lyric Generator: build a tagged prompt from the form and render output ---
st.title("Lyric Generator")
artist_name_for_generator = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")
if st.button("Submit"):
    # One tagged field per line: title, artist, description.
    prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_for_generator}\n[Song Description] {song_description}"
    print(f"Prompt: {prompt}")
    with st.spinner("Generating Lyrics..."):
        # 1000 is passed straight to generate_lyrics (presumably a length cap — confirm).
        lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.success("Finished Generating Lyrics")
    print(f"Lyrics: {lyrics}")
    # Cut at the first end-of-song marker wherever it occurs; previously it was
    # honoured only when alone on its line, so a marker sharing a line with
    # text was displayed and everything after it was rendered too.
    lyrics = lyrics.split("<END_OF_SONG>", 1)[0]
    for line in lyrics.split("\n"):
        if line.lstrip().startswith("["):
            # Section headers such as "[Chorus]": strip so the closing "**"
            # is not preceded by whitespace (which would break markdown bold).
            st.markdown(f"**{line.strip()}**")
        else:
            st.write(line)
# --- Lyric Evaluator: score how well pasted lyrics match a chosen artist ---
st.title("Lyric Evaluator")
# NOTE: these widget keys say "generator" (and misspell it) although this is
# the evaluator section; they are left unchanged so widget identity/session
# state does not change.
artist_name_for_evaluator = st.selectbox(
    "Artist", artist_names_list, key="genorator_select"
)
evaluator_song_lyrics = st.text_area("Song Lyrics")
if st.button("Submit", key="generator_submit"):
    # Remove blank lines between stanzas before scoring. The former
    # .replace("\n\n", "\n") only halved newline runs (three newlines still
    # left a blank line) and kept whitespace-only lines.
    evaluator_lyrics_clean = "\n".join(
        line for line in evaluator_song_lyrics.splitlines() if line.strip()
    )
    if not evaluator_lyrics_clean:
        st.warning("Please enter some song lyrics to evaluate.")
    else:
        with st.spinner("Scoring Lyrics..."):
            score = lyric_evaluator_model.generate_artist_coherency_score(
                artist_name_for_evaluator, evaluator_lyrics_clean
            )
        print(f"Score: {score}")
        st.write(f"Score: {score}")
# --- Lyric Predictor: guess the artist and coherency label for pasted lyrics ---
st.title("Lyric Predictor")
predictor_song_lyrics = st.text_area("Song Lyrics", key="predictor song lyrics")
if st.button("Submit", key="predictor_submit"):
    # Normalise once (drop every blank/whitespace-only line) and reuse the
    # result for both predictions; the former .replace("\n\n", "\n") was
    # computed twice and left blank lines behind for 3+ consecutive newlines.
    predictor_lyrics_clean = "\n".join(
        line for line in predictor_song_lyrics.splitlines() if line.strip()
    )
    if not predictor_lyrics_clean:
        st.warning("Please enter some song lyrics to analyse.")
    else:
        with st.spinner("Predicting Artist and Coherency..."):
            artist, artist_score = lyric_evaluator_model.predict_artist(
                predictor_lyrics_clean
            )
            coherency, coherency_score = lyric_evaluator_model.predict_coherency(
                predictor_lyrics_clean
            )
        # Log to the server console and show in the page, artist result first.
        for message in (
            f"Predicted {artist} with a score of {artist_score}",
            f"Predicted {coherency} with a score of {coherency_score}",
        ):
            print(message)
            st.write(message)