"""Streamlit app: generate song lyrics for a chosen artist and score them.

The user picks an artist, enters a song title and description, and on
submit the app builds a prompt, generates lyrics with
``LyricGeneratorModel`` and scores them with
``ArtistCoherencyModel.generate_artist_coherency_score``.
"""

import pandas as pd
import streamlit as st

from ArtistCoherencyModel import ArtistCoherencyModel
from LyricGeneratorModel import LyricGeneratorModel

# CSV read by get_artists(); must contain a "name" column.
ARTISTS_CSV_PATH = "artists.csv"
EVALUATOR_MODEL_ID = "tjl223/artist-coherency-ensemble"
GENERATOR_MODEL_ID = "tjl223/testllama2-qlora-lyric-generator-with-description"
# Second positional argument passed to generate_lyrics(); presumably a
# maximum generation length -- TODO confirm against LyricGeneratorModel.
GENERATION_LENGTH = 1000


@st.cache_data
def get_artists():
    """Return the values of the "name" column of ``artists.csv`` as a list.

    This is plain data, so it is cached with ``st.cache_data``. That hands
    each caller its own copy, so no session can mutate the list another
    session sees (``cache_resource`` would share a single list object).
    """
    artists_df = pd.read_csv(ARTISTS_CSV_PATH)
    return artists_df["name"].tolist()


# show_spinner labels Streamlit's built-in cache-miss spinner, so no extra
# st.spinner is needed (nesting one produced two spinners).
@st.cache_resource(show_spinner="Loading Evaluation Model")
def get_evaluator_model():
    """Load the artist-coherency evaluation model.

    Cached with ``st.cache_resource``, so the loaded model object is shared
    across sessions and reruns instead of being reloaded each time.
    """
    lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(EVALUATOR_MODEL_ID)
    st.success("Finished Loading Evaluation Model")
    return lyric_evaluator_model


@st.cache_resource(show_spinner="Loading Generator Model")
def get_generator_model():
    """Load the lyric generator model.

    Cached with ``st.cache_resource``, so the loaded model object is shared
    across sessions and reruns instead of being reloaded each time.
    """
    lyric_generator_model = LyricGeneratorModel(GENERATOR_MODEL_ID)
    st.success("Finished Loading Generator Model")
    return lyric_generator_model


def _build_prompt(song_title, artist_name, song_description):
    """Format the generator prompt: one "[Field] value" line per input."""
    return (
        f"[Song Title] {song_title}\n"
        f"[Song Artist] {artist_name}\n"
        f"[Song Description] {song_description}"
    )


lyric_evaluator_model = get_evaluator_model()
lyric_generator_model = get_generator_model()
artist_names_list = get_artists()

artist_name = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")
submit_button = st.button("Submit")

if submit_button:
    if not song_title.strip():
        # Don't spend a full generation run on a prompt with no title.
        st.warning("Please enter a song title before submitting.")
    else:
        prompt = _build_prompt(song_title, artist_name, song_description)
        print(f"Prompt: {prompt}")
        # st.text instead of st.write: st.write renders Markdown, which
        # merges single-newline-separated lines into one paragraph and would
        # interpret Markdown syntax in user/model text.
        st.text(prompt)

        with st.spinner("Generating Lyrics"):
            lyrics = lyric_generator_model.generate_lyrics(prompt, GENERATION_LENGTH)
        print(f"Lyrics: {lyrics}")
        # NOTE(review): assumes generate_lyrics returns a str -- confirm.
        st.text(lyrics)

        with st.spinner("Scoring Lyrics"):
            score = lyric_evaluator_model.generate_artist_coherency_score(
                artist_name, lyrics
            )
        print(f"Score: {score}")
        st.write(f"Score: {score}")