"""Streamlit app with three tools built on two pretrained lyric models.

* Lyric Generator -- builds a prompt from an artist, song title and
  description, then asks ``LyricGeneratorModel`` for lyrics.
* Lyric Evaluator -- scores how coherent pasted lyrics are with a chosen
  artist using ``ArtistCoherencyModel``.
* Lyric Predictor -- asks ``ArtistCoherencyModel`` to predict the artist and
  the coherency label for pasted lyrics.
"""

import re

import pandas as pd
import streamlit as st

from ArtistCoherencyModel import ArtistCoherencyModel
from LyricGeneratorModel import LyricGeneratorModel

# Matches any run of two or more newlines, i.e. one or more blank lines.
_BLANK_LINES_RE = re.compile(r"\n{2,}")


def _collapse_blank_lines(text):
    """Return *text* with every run of consecutive newlines reduced to one.

    The original ``text.replace("\\n\\n", "\\n")`` only halved each run, so
    three or more newlines still left a blank line behind.
    """
    return _BLANK_LINES_RE.sub("\n", text)


@st.cache_data
def get_artists():
    """Return the artist names from the ``name`` column of ``artists.csv``.

    ``st.cache_data`` (rather than ``st.cache_resource``) is used because this
    is plain data: each caller gets its own copy instead of one list shared
    across every session.
    """
    artists_df = pd.read_csv("artists.csv")
    return list(artists_df["name"])


@st.cache_resource
def get_evaluator_model():
    """Load the artist-coherency evaluator; cached so it is loaded only once."""
    with st.spinner("Loading Evaluation Model..."):
        model = ArtistCoherencyModel.from_pretrained(
            "tjl223/artist-coherency-ensemble"
        )
    st.success("Finished Loading Evaluation Model")
    return model


@st.cache_resource
def get_generator_model():
    """Load the lyric generator; cached so it is loaded only once."""
    with st.spinner("Loading Generator Model..."):
        model = LyricGeneratorModel("tjl223/llama2-qlora-lyric-generator")
    st.success("Finished Loading Generator Model")
    return model


lyric_evaluator_model = get_evaluator_model()
lyric_generator_model = get_generator_model()
artist_names_list = get_artists()

# ---------------------------------------------------------------- Generator
st.title("Lyric Generator")
artist_name_for_generator = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")

if st.button("Submit"):
    prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_for_generator}\n[Song Description] {song_description}"
    print(f"Prompt: {prompt}")
    with st.spinner("Generating Lyrics..."):
        # 1000 is passed straight through to generate_lyrics; presumably a
        # max-token/length budget -- TODO confirm against LyricGeneratorModel.
        lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.success("Finished Generating Lyrics")
    print(f"Lyrics: {lyrics}")

    for line in lyrics.split("\n"):
        if line.startswith("["):
            # Bracketed lines (e.g. section headers) are rendered in bold.
            st.markdown(f"**{line}**")
        elif not line.strip():
            # NOTE(review): output stops at the first blank line. If the
            # generator separates sections with blank lines, later sections
            # are hidden -- confirm this truncation is intended.
            break
        else:
            st.write(line)

# ---------------------------------------------------------------- Evaluator
st.title("Lyric Evaluator")
# NOTE: these widget keys look copy-pasted from the generator section; they
# only need to be unique on the page, so they are left unchanged.
artist_name_for_evaluator = st.selectbox(
    "Artist", artist_names_list, key="genorator_select"
)
evaluator_song_lyrics = st.text_area("Song Lyrics")

if st.button("Submit", key="generator_submit"):
    if not evaluator_song_lyrics.strip():
        st.warning("Please enter some song lyrics to evaluate.")
    else:
        score = lyric_evaluator_model.generate_artist_coherency_score(
            artist_name_for_evaluator,
            _collapse_blank_lines(evaluator_song_lyrics),
        )
        print(f"Score: {score}")
        st.write(f"Score: {score}")

# ---------------------------------------------------------------- Predictor
st.title("Lyric Predictor")
predictor_song_lyrics = st.text_area("Song Lyrics", key="predictor song lyrics")

if st.button("Submit", key="predictor_submit"):
    if not predictor_song_lyrics.strip():
        st.warning("Please enter some song lyrics to analyze.")
    else:
        # Normalize once; both predictions use the same cleaned text.
        cleaned_lyrics = _collapse_blank_lines(predictor_song_lyrics)
        artist, artist_score = lyric_evaluator_model.predict_artist(cleaned_lyrics)
        coherency, coherency_score = lyric_evaluator_model.predict_coherency(
            cleaned_lyrics
        )
        print(f"Predicted {artist} with a score of {artist_score}")
        st.write(f"Predicted {artist} with a score of {artist_score}")
        print(f"Predicted {coherency} with a score of {coherency_score}")
        st.write(f"Predicted {coherency} with a score of {coherency_score}")