CS-4700-Demo / app.py
tjl223's picture
fixed bug in lyric generator parsing
938825c
raw history blame
No virus
2.26 kB
from ArtistCoherencyModel import ArtistCoherencyModel
import streamlit as st
import pandas as pd
from LyricGeneratorModel import LyricGeneratorModel
@st.cache_data
def get_artists():
    """Return the artist names listed in ``artists.csv``.

    Reads ``artists.csv`` (resolved relative to the current working
    directory) and returns the values of its ``name`` column as a list,
    in file order.

    The result is cached with ``st.cache_data`` rather than
    ``st.cache_resource``: the return value is plain, serializable data,
    and ``cache_data`` hands each caller its own copy, so code that mutates
    the returned list cannot alter what other reruns or sessions receive.

    Returns:
        list: The entries of the ``name`` column.

    Raises:
        FileNotFoundError: If ``artists.csv`` does not exist.
        KeyError: If the CSV has no ``name`` column.
    """
    artists_df = pd.read_csv("artists.csv")
    return list(artists_df["name"])
@st.cache_resource
def get_evaluator_model():
    """Load the artist-coherency evaluation model.

    A spinner is displayed while ``ArtistCoherencyModel.from_pretrained``
    fetches the ``tjl223/artist-coherency-ensemble`` weights, followed by a
    success banner once loading has finished.

    Returns:
        The object returned by ``ArtistCoherencyModel.from_pretrained``.
    """
    with st.spinner("Loading Evaluation Model..."):
        model = ArtistCoherencyModel.from_pretrained(
            "tjl223/artist-coherency-ensemble"
        )
    st.success("Finished Loading Evaluation Model")
    return model
@st.cache_resource
def get_generator_model():
    """Load the lyric generator model.

    A spinner is displayed while ``LyricGeneratorModel`` is constructed for
    ``tjl223/testllama2-qlora-lyric-generator-with-description``, followed
    by a success banner once construction has finished.

    Returns:
        The constructed ``LyricGeneratorModel`` instance.
    """
    with st.spinner("Loading Generator Model..."):
        model = LyricGeneratorModel(
            "tjl223/testllama2-qlora-lyric-generator-with-description"
        )
    st.success("Finished Loading Generator Model")
    return model
# Shared state for both UI sections below. Each loader is wrapped in a
# Streamlit cache decorator, so these calls are expected to do the heavy
# work only on the first run and return the cached result on later reruns.
# The call order also fixes the order in which the loaders' spinners and
# status messages are emitted on the page.
lyric_evaluator_model = get_evaluator_model()
lyric_generator_model = get_generator_model()
artist_names_list = get_artists()
# ---------------------------------------------------------------------------
# Lyric Generator: collect an artist, a title and a description, then ask the
# generator model for lyrics and render them.
# ---------------------------------------------------------------------------
generator_title = st.title("Lyric Generator")
artist_name_for_generator = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")

if st.button("Submit"):
    # The prompt is three tagged lines: title, artist, description.
    prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_for_generator}\n[Song Description] {song_description}"
    print(f"Prompt: {prompt}")

    lyrics = ""
    with st.spinner("Generating Lyrics..."):
        lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.success("Finished Generating Lyrics")
    print(f"Lyrics: {lyrics}")

    # Render line by line: lines that open with "[" are shown in bold,
    # every other line is written as-is.
    for line in lyrics.split("\n"):
        if not line.startswith("["):
            st.write(line)
        else:
            st.markdown(f"**{line}**")
# ---------------------------------------------------------------------------
# Lyric Evaluator: score how well the pasted lyrics match the chosen artist.
# ---------------------------------------------------------------------------
evaluator_title = st.title("Lyric Evaluator")

# The explicit keys keep these widgets distinct from the generator section's
# identically labelled "Artist" selectbox and "Submit" button. The key
# strings are widget identifiers and are kept exactly as they were.
artist_name_for_evaluator = st.selectbox(
    label="Artist",
    options=artist_names_list,
    key="genorator_select",
)
song_lyrics = st.text_area("Song Lyrics")

if st.button(label="Submit", key="genorator_submit"):
    score = lyric_evaluator_model.generate_artist_coherency_score(
        artist_name_for_evaluator, song_lyrics
    )
    # The same message goes to the server log and to the page.
    score_message = f"Score: {score}"
    print(score_message)
    st.write(score_message)