tjl223 commited on
Commit
eb95ef9
1 Parent(s): e682173

cache models and data

Browse files
Files changed (1) hide show
  1. app.py +32 -16
app.py CHANGED
@@ -4,22 +4,38 @@ import pandas as pd
4
 
5
  from LyricGeneratorModel import LyricGeneratorModel
6
 
7
- artists_df = pd.read_csv("artists.csv")
8
- artist_names_list = list(artists_df["name"])
9
-
10
- lyric_evaluator_model = None
11
- with st.spinner("Loading Evaluation Model"):
12
- lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(
13
- "tjl223/artist-coherency-ensemble"
14
- )
15
- st.success("Finished Loading Evaluation Model")
16
-
17
- lyric_generator_model = None
18
- with st.spinner("Loading Generator Model"):
19
- lyric_generator_model = LyricGeneratorModel(
20
- "tjl223/testllama2-qlora-lyric-generator-with-description"
21
- )
22
- st.success("Finished Loading Generator Model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  artist_name = st.selectbox("Artist", artist_names_list)
25
  song_title = st.text_input("Song Title")
 
4
 
5
  from LyricGeneratorModel import LyricGeneratorModel
6
 
7
+
8
+ @st.cache_resource
9
+ def get_artists():
10
+ artists_df = pd.read_csv("artists.csv")
11
+ return list(artists_df["name"])
12
+
13
+
14
+ @st.cache_resource
15
+ def get_evaluator_model():
16
+ lyric_evaluator_model = None
17
+ with st.spinner("Loading Evaluation Model"):
18
+ lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(
19
+ "tjl223/artist-coherency-ensemble"
20
+ )
21
+ st.success("Finished Loading Evaluation Model")
22
+ return lyric_evaluator_model
23
+
24
+
25
+ @st.cache_resource
26
+ def get_generator_model():
27
+ lyric_generator_model = None
28
+ with st.spinner("Loading Generator Model"):
29
+ lyric_generator_model = LyricGeneratorModel(
30
+ "tjl223/testllama2-qlora-lyric-generator-with-description"
31
+ )
32
+ st.success("Finished Loading Generator Model")
33
+ return lyric_generator_model
34
+
35
+
36
+ lyric_evaluator_model = get_evaluator_model()
37
+ lyric_generator_model = get_generator_model()
38
+ artist_names_list = get_artists()
39
 
40
  artist_name = st.selectbox("Artist", artist_names_list)
41
  song_title = st.text_input("Song Title")