File size: 3,051 Bytes
0d812a0
160d6a0
0d812a0
160d6a0
f95071d
 
eb95ef9
 
 
 
 
 
 
 
 
 
938825c
eb95ef9
 
 
 
 
 
 
 
 
 
938825c
eb95ef9
91d4c63
eb95ef9
 
 
 
 
 
 
 
e682173
7a8a872
 
0e2132d
 
 
7a8a872
 
e682173
938825c
 
 
 
e682173
7a8a872
 
 
 
4e5a7c4
 
7a8a872
 
 
dd54a42
 
 
4181b0f
7a8a872
4181b0f
7a8a872
4181b0f
7a8a872
e682173
 
4181b0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from ArtistCoherencyModel import ArtistCoherencyModel
import streamlit as st
import pandas as pd

from LyricGeneratorModel import LyricGeneratorModel


@st.cache_data
def get_artists() -> list:
    """Load the artist names offered in the "Artist" dropdowns.

    Reads ``artists.csv`` (path relative to the current working directory)
    and returns the values of its ``name`` column.

    The result is plain data, so it is cached with ``st.cache_data``: every
    caller receives its own copy of the list and cannot accidentally mutate
    the value seen by other reruns or sessions. ``st.cache_resource`` is
    meant for shared, unserialisable objects such as models and would hand
    every caller the very same list object.

    Returns:
        list: One entry per row of the ``name`` column, in file order.

    Raises:
        FileNotFoundError: If ``artists.csv`` does not exist.
        KeyError: If the CSV has no ``name`` column.
    """
    artists_df = pd.read_csv("artists.csv")
    return list(artists_df["name"])


@st.cache_resource
def get_evaluator_model():
    """Return the artist-coherency evaluation model.

    The model is created with ``ArtistCoherencyModel.from_pretrained`` using
    the identifier ``tjl223/artist-coherency-ensemble``. A spinner is shown
    while the model loads and a success message is shown once loading has
    finished. The function is wrapped in ``st.cache_resource`` so the loaded
    instance is cached rather than rebuilt by each call.
    """
    with st.spinner("Loading Evaluation Model..."):
        evaluator = ArtistCoherencyModel.from_pretrained(
            "tjl223/artist-coherency-ensemble"
        )
    # Reported only after the spinner context has closed.
    st.success("Finished Loading Evaluation Model")
    return evaluator


@st.cache_resource
def get_generator_model():
    """Return the lyric generator model.

    The model is created by calling ``LyricGeneratorModel`` with the
    identifier ``tjl223/llama2-qlora-lyric-generator``. A spinner is shown
    while the model loads and a success message is shown once loading has
    finished. The function is wrapped in ``st.cache_resource`` so the loaded
    instance is cached rather than rebuilt by each call.
    """
    with st.spinner("Loading Generator Model..."):
        generator = LyricGeneratorModel("tjl223/llama2-qlora-lyric-generator")
    # Reported only after the spinner context has closed.
    st.success("Finished Loading Generator Model")
    return generator


# Fetch the shared models and the artist list. The getters are cached, so
# script reruns reuse the cached results instead of reloading everything.
lyric_evaluator_model = get_evaluator_model()
lyric_generator_model = get_generator_model()
artist_names_list = get_artists()

# --- Lyric Generator: build a prompt from the form and render the output ---
generator_title = st.title("Lyric Generator")
artist_name_for_generator = st.selectbox("Artist", artist_names_list)
song_title = st.text_input("Song Title")
song_description = st.text_area("Song Description")

if st.button("Submit"):
    # Prompt layout: one "[Tag] value" line per form field.
    prompt = (
        f"[Song Title] {song_title}\n"
        f"[Song Artist] {artist_name_for_generator}\n"
        f"[Song Description] {song_description}"
    )
    print(f"Prompt: {prompt}")
    with st.spinner("Generating Lyrics..."):
        # NOTE(review): 1000 is presumably a generation length limit --
        # confirm against LyricGeneratorModel.generate_lyrics.
        lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
    st.success("Finished Generating Lyrics")
    print(f"Lyrics: {lyrics}")
    for line in lyrics.split("\n"):
        # Strip once and use the same value for both marker checks, so an
        # indented "[Chorus]" is still recognised as a section header.
        stripped = line.strip()
        if stripped == "<END_OF_SONG>":
            break
        if stripped.startswith("["):
            # Bold section headers. The stripped text is used because
            # whitespace just inside the ** markers stops Markdown from
            # treating them as emphasis.
            st.markdown(f"**{stripped}**")
        else:
            st.write(line)

# --- Lyric Evaluator: score how well the lyrics fit the chosen artist ---
evaluator_title = st.title("Lyric Evaluator")
# NOTE(review): the widget keys in this section say "genorator"/"generator"
# although they belong to the evaluator. They are left unchanged because a
# widget's key is its identity in st.session_state.
artist_name_for_evaluator = st.selectbox(
    "Artist", artist_names_list, key="genorator_select"
)
evaluator_song_lyrics = st.text_area("Song Lyrics")

if st.button("Submit", key="generator_submit"):
    if not evaluator_song_lyrics.strip():
        # Nothing to score: skip the model call rather than evaluating
        # empty or whitespace-only text.
        st.warning("Please enter song lyrics to evaluate.")
    else:
        # Replace each "\n\n" with "\n" before scoring.
        # NOTE(review): this is a single pass, so a run of three or more
        # newlines still leaves a blank line -- confirm whether that matches
        # how the model's training lyrics were preprocessed.
        score = lyric_evaluator_model.generate_artist_coherency_score(
            artist_name_for_evaluator, evaluator_song_lyrics.replace("\n\n", "\n")
        )
        print(f"Score: {score}")
        st.write(f"Score: {score}")

# --- Lyric Predictor: predict the artist and a coherency label for lyrics ---
predictor_title = st.title("Lyric Predictor")
predictor_song_lyrics = st.text_area("Song Lyrics", key="predictor song lyrics")

if st.button("Submit", key="predictor_submit"):
    if not predictor_song_lyrics.strip():
        # Nothing to classify: skip both model calls rather than predicting
        # on empty or whitespace-only text.
        st.warning("Please enter song lyrics to run a prediction.")
    else:
        # Normalise once and feed the same text to both predictions.
        # NOTE(review): single-pass replace; a run of three or more newlines
        # still leaves a blank line -- confirm against training preprocessing.
        normalized_lyrics = predictor_song_lyrics.replace("\n\n", "\n")
        artist, artist_score = lyric_evaluator_model.predict_artist(
            normalized_lyrics
        )
        coherency, coherency_score = lyric_evaluator_model.predict_coherency(
            normalized_lyrics
        )
        # Each message is logged to stdout and then shown in the app.
        for message in (
            f"Predicted {artist} with a score of {artist_score}",
            f"Predicted {coherency} with a score of {coherency_score}",
        ):
            print(message)
            st.write(message)