File size: 2,260 Bytes
0d812a0
160d6a0
0d812a0
160d6a0
f95071d
 
eb95ef9
 
 
 
 
 
 
 
 
 
938825c
eb95ef9
 
 
 
 
 
 
 
 
 
938825c
eb95ef9
 
 
 
 
 
 
 
 
 
e682173
7a8a872
 
0e2132d
 
 
7a8a872
 
e682173
938825c
 
 
 
e682173
7a8a872
 
 
 
 
 
 
dd54a42
 
 
7a8a872
 
dd54a42
7a8a872
 
 
e682173
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from ArtistCoherencyModel import ArtistCoherencyModel
import streamlit as st
import pandas as pd

from LyricGeneratorModel import LyricGeneratorModel


# st.cache_data (not st.cache_resource) is used because the result is plain,
# serializable data: each caller receives its own copy of the list, so an
# accidental in-place mutation cannot leak into other sessions or reruns.
@st.cache_data
def get_artists():
    """Return the artist names listed in ``artists.csv``.

    Reads ``artists.csv`` (path is relative to the process working
    directory) and returns the contents of its ``name`` column.

    Returns:
        list: The values of the ``name`` column, in file order.

    Raises:
        FileNotFoundError: If ``artists.csv`` cannot be found.
        KeyError: If the CSV has no ``name`` column.
    """
    artists_df = pd.read_csv("artists.csv")
    return artists_df["name"].tolist()


@st.cache_resource
def get_evaluator_model():
    """Load the artist-coherency evaluation model.

    Builds the model with ``ArtistCoherencyModel.from_pretrained`` while a
    Streamlit spinner is displayed, then emits a success message.  The
    function is decorated with ``st.cache_resource``.

    Returns:
        The object returned by ``ArtistCoherencyModel.from_pretrained``.
    """
    model_id = "tjl223/artist-coherency-ensemble"  # presumably a model hub id
    with st.spinner("Loading Evaluation Model..."):
        coherency_model = ArtistCoherencyModel.from_pretrained(model_id)
    # Reported only after the spinner context has exited.
    st.success("Finished Loading Evaluation Model")
    return coherency_model


@st.cache_resource
def get_generator_model():
    """Load the lyric generator model.

    Instantiates ``LyricGeneratorModel`` while a Streamlit spinner is
    displayed, then emits a success message.  The function is decorated
    with ``st.cache_resource``.

    Returns:
        The constructed ``LyricGeneratorModel`` instance.
    """
    model_id = "tjl223/testllama2-qlora-lyric-generator-with-description"
    with st.spinner("Loading Generator Model..."):
        generator = LyricGeneratorModel(model_id)
    # Reported only after the spinner context has exited.
    st.success("Finished Loading Generator Model")
    return generator


# Module-level setup, executed each time Streamlit runs this script.
# Each helper below is wrapped in a Streamlit cache decorator.  The calls run
# in this order, so the evaluator's spinner/success messages are emitted
# before the generator's.  The three names bound here are read by the UI
# sections further down.
lyric_evaluator_model = get_evaluator_model()
lyric_generator_model = get_generator_model()
artist_names_list = get_artists()

# ---------------------------------------------------------------------------
# Lyric generator section: collect artist / title / description, generate
# lyrics on demand, and render the most recently generated lyrics.
# ---------------------------------------------------------------------------
generator_title = st.title("Lyric Generator")
artist_name_for_generator = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")

if st.button("Submit"):
    # The prompt is three "[Tag] value" lines: title, artist, description.
    prompt = (
        f"[Song Title] {song_title}\n"
        f"[Song Artist] {artist_name_for_generator}\n"
        f"[Song Description] {song_description}"
    )
    print(f"Prompt: {prompt}")
    with st.spinner("Generating Lyrics..."):
        # NOTE(review): the second argument (1000) looks like a length/token
        # budget -- confirm against LyricGeneratorModel.generate_lyrics.
        lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.success("Finished Generating Lyrics")
    print(f"Lyrics: {lyrics}")
    # st.button() is True only on the rerun triggered by the click.  Any later
    # widget interaction (e.g. using the evaluator below) reruns the script
    # with the button False, so the lyrics are stored in session state and
    # rendered from there instead of only inside this branch.
    st.session_state["generated_lyrics"] = lyrics

# Render the latest lyrics, if any; skips cleanly when nothing has been
# generated yet or the generator produced an empty result.
generated_lyrics = st.session_state.get("generated_lyrics")
if generated_lyrics:
    for line in generated_lyrics.split("\n"):
        if line.startswith("["):
            # Lines starting with "[" are shown in bold.  Trailing whitespace
            # is stripped because Markdown does not close "**" after a space.
            st.markdown(f"**{line.rstrip()}**")
        else:
            st.write(line)

# ---------------------------------------------------------------------------
# Lyric evaluator section: score how well the supplied lyrics match the
# selected artist.
# ---------------------------------------------------------------------------
evaluator_title = st.title("Lyric Evaluator")
# Explicit keys are needed because these widgets reuse the "Artist" and
# "Submit" labels of the generator widgets.  The keys are spelled "genorator"
# even though they belong to the evaluator; they are runtime widget
# identifiers and are deliberately left unchanged.
artist_name_for_evaluator = st.selectbox(
    "Artist", artist_names_list, key="genorator_select"
)
song_lyrics = st.text_area("Song Lyrics")

if st.button("Submit", key="genorator_submit"):
    if not song_lyrics.strip():
        # Nothing to score: avoid running the model on empty input and
        # reporting a meaningless number.
        st.warning("Please enter some lyrics to evaluate.")
    else:
        # Spinner mirrors the feedback given by the generator section.
        with st.spinner("Evaluating Lyrics..."):
            score = lyric_evaluator_model.generate_artist_coherency_score(
                artist_name_for_evaluator, song_lyrics
            )
        print(f"Score: {score}")
        st.write(f"Score: {score}")