Spaces:
Paused
Paused
File size: 2,393 Bytes
0d812a0 160d6a0 0d812a0 160d6a0 f95071d eb95ef9 938825c eb95ef9 938825c eb95ef9 91d4c63 eb95ef9 e682173 7a8a872 0e2132d 7a8a872 e682173 938825c e682173 7a8a872 4e5a7c4 7a8a872 dd54a42 7a8a872 dd54a42 7a8a872 4de628b 7a8a872 4de628b e682173 4de628b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import re

import pandas as pd
import streamlit as st

from ArtistCoherencyModel import ArtistCoherencyModel
from LyricGeneratorModel import LyricGeneratorModel
@st.cache_data
def get_artists():
    """Return the artist names offered in both "Artist" dropdowns.

    Reads the ``name`` column of ``artists.csv``. The path is relative, so it
    is resolved against the process working directory.

    Cached with ``st.cache_data`` rather than ``st.cache_resource``. The
    result is plain, picklable data, and ``cache_data`` gives each caller its
    own copy. A caller that mutates the returned list therefore cannot
    corrupt the cached value that every session shares.

    Returns:
        list: the values of the ``name`` column, in file order.
    """
    # Only the "name" column is used, so don't parse the rest of the file.
    artists_df = pd.read_csv("artists.csv", usecols=["name"])
    return artists_df["name"].tolist()
@st.cache_resource
def get_evaluator_model():
    """Load and return the artist-coherency evaluator model.

    Because of ``st.cache_resource``, the first call builds the model via
    ``ArtistCoherencyModel.from_pretrained`` and later reruns reuse that same
    object instead of loading it again. A spinner is shown while loading and
    a success message once loading completes.

    Returns:
        The object returned by ``ArtistCoherencyModel.from_pretrained``.
    """
    # Presumably a Hugging Face Hub repository id -- confirm in
    # ArtistCoherencyModel.from_pretrained.
    repo_id = "tjl223/artist-coherency-ensemble"
    with st.spinner("Loading Evaluation Model..."):
        evaluator = ArtistCoherencyModel.from_pretrained(repo_id)
    st.success("Finished Loading Evaluation Model")
    return evaluator
@st.cache_resource
def get_generator_model():
    """Construct and return the lyric generator model.

    Because of ``st.cache_resource``, the first call builds a
    ``LyricGeneratorModel`` and later reruns reuse that same object. A spinner
    is shown while loading and a success message once it finishes.

    Returns:
        The ``LyricGeneratorModel`` instance built from the id below.
    """
    # Presumably a Hugging Face Hub repository id for a QLoRA fine-tune --
    # confirm in LyricGeneratorModel.__init__.
    repo_id = "tjl223/llama2-qlora-lyric-generator"
    with st.spinner("Loading Generator Model..."):
        generator = LyricGeneratorModel(repo_id)
    st.success("Finished Loading Generator Model")
    return generator
# Fetch the (cached) models and the artist list used by both page sections.
# Order matters for what the user sees: each loader renders its own spinner
# and success message, so the evaluator is loaded before the generator.
lyric_evaluator_model, lyric_generator_model = (
    get_evaluator_model(),
    get_generator_model(),
)
artist_names_list = get_artists()
# ---- Lyric generator: build a tagged prompt and render the generated song.
generator_title = st.title("Lyric Generator")
artist_name_for_generator = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")
if st.button("Submit"):
    # Tagged header lines; presumably the exact format the generator was
    # fine-tuned on, so keep these tags byte-for-byte.
    prompt = (
        f"[Song Title] {song_title}\n"
        f"[Song Artist] {artist_name_for_generator}\n"
        f"[Song Description] {song_description}"
    )
    print(f"Prompt: {prompt}")
    with st.spinner("Generating Lyrics..."):
        # 1000 is passed straight through; presumably a length/token budget
        # -- confirm in LyricGeneratorModel.generate_lyrics.
        lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.success("Finished Generating Lyrics")
    print(f"Lyrics: {lyrics}")
    # splitlines() also removes a trailing "\r" from CRLF-terminated lines.
    for line in lyrics.splitlines():
        stripped = line.strip()
        # Stop rendering at the end-of-song marker line.
        if stripped == "<END_OF_SONG>":
            break
        if stripped.startswith("["):
            # Bold section headers such as "[Chorus]". Strip first, because
            # Markdown will not close "**...**" when whitespace precedes the
            # closing "**". Otherwise "[Chorus] " shows literal asterisks.
            st.markdown(f"**{stripped}**")
        else:
            st.write(line)
# ---- Lyric evaluator: score how well pasted lyrics match a chosen artist.
evaluator_title = st.title("Lyric Evaluator")
# The explicit keys keep these widgets distinct from the identically
# labelled "Artist" selectbox and "Submit" button in the generator section.
# NOTE: the keys say "genorator" although they belong to the evaluator. They
# are kept as-is so the widget identities stay unchanged.
artist_name_for_evaluator = st.selectbox(
    "Artist", artist_names_list, key="genorator_select"
)
song_lyrics = st.text_area("Song Lyrics")
if st.button("Submit", key="genorator_submit"):
    # Collapse every run of blank (or whitespace-only) lines into a single
    # newline. The previous replace("\n\n", "\n") only halved such runs, so
    # three or more consecutive newlines still left blank lines behind.
    normalized_lyrics = re.sub(r"\n\s*\n", "\n", song_lyrics)
    if not normalized_lyrics.strip():
        # Nothing to score; don't send empty text to the model.
        st.warning("Please enter some song lyrics to evaluate.")
    else:
        score = lyric_evaluator_model.generate_artist_coherency_score(
            artist_name_for_evaluator, normalized_lyrics
        )
        # Server-side debug logging of the raw input, score, and model input.
        print(song_lyrics)
        print(f"Score: {score}")
        st.write(f"Score: {score}")
        print(normalized_lyrics)
|