tjl223 commited on
Commit
e682173
1 Parent(s): 3d922cc

adding evaluator

Browse files
Files changed (2) hide show
  1. LyricGeneratorModel.py +3 -2
  2. app.py +21 -10
LyricGeneratorModel.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
 
3
  from peft import PeftModel, PeftConfig
4
 
@@ -6,10 +6,11 @@ from peft import PeftModel, PeftConfig
6
  class LyricGeneratorModel:
7
  def __init__(self, repo_id: str):
8
  config = PeftConfig.from_pretrained(repo_id)
 
9
  model = AutoModelForCausalLM.from_pretrained(
10
  config.base_model_name_or_path,
11
  return_dict=True,
12
- load_in_8bit=True,
13
  device_map="auto",
14
  )
15
  self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
2
 
3
  from peft import PeftModel, PeftConfig
4
 
 
6
  class LyricGeneratorModel:
7
  def __init__(self, repo_id: str):
8
  config = PeftConfig.from_pretrained(repo_id)
9
+ bnb_config = BitsAndBytesConfig(load_in_8bit=True)
10
  model = AutoModelForCausalLM.from_pretrained(
11
  config.base_model_name_or_path,
12
  return_dict=True,
13
+ quantization_config=bnb_config,
14
  device_map="auto",
15
  )
16
  self.tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
app.py CHANGED
@@ -7,21 +7,32 @@ from LyricGeneratorModel import LyricGeneratorModel
7
  artists_df = pd.read_csv("artists.csv")
8
  artist_names_list = list(artists_df["name"])
9
 
 
 
 
 
 
 
10
 
11
- artist_name_input = st.selectbox("Artist", artist_names_list)
 
 
 
 
 
 
 
12
  song_title = st.text_input("Song Title")
13
  song_description = st.text_area("Song Description")
14
  submit_button = st.button("Submit")
15
 
16
- lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(
17
- "tjl223/artist-coherency-ensemble"
18
- )
19
-
20
- lyric_generator_model = LyricGeneratorModel(
21
- "tjl223/testllama2-qlora-lyric-generator-with-description"
22
- )
23
-
24
  if submit_button:
25
- prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_input}\n[Song Description] {song_description}"
 
 
26
  lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
 
27
  st.write(lyrics)
 
 
 
 
7
  artists_df = pd.read_csv("artists.csv")
8
  artist_names_list = list(artists_df["name"])
9
 
10
+ lyric_evaluator_model = None
11
+ with st.spinner("Loading Evaluation Model"):
12
+ lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(
13
+ "tjl223/artist-coherency-ensemble"
14
+ )
15
+ st.success("Finished Loading Evaluation Model")
16
 
17
+ lyric_generator_model = None
18
+ with st.spinner("Loading Generator Model"):
19
+ lyric_generator_model = LyricGeneratorModel(
20
+ "tjl223/testllama2-qlora-lyric-generator-with-description"
21
+ )
22
+ st.success("Finished Loading Generator Model")
23
+
24
+ artist_name = st.selectbox("Artist", artist_names_list)
25
  song_title = st.text_input("Song Title")
26
  song_description = st.text_area("Song Description")
27
  submit_button = st.button("Submit")
28
 
 
 
 
 
 
 
 
 
29
  if submit_button:
30
+ prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name}\n[Song Description] {song_description}"
31
+ print(f"Prompt: {prompt}")
32
+ st.write(prompt)
33
  lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
34
+ print(f"Lyrics: {lyrics}")
35
  st.write(lyrics)
36
+ score = lyric_evaluator_model.generate_artist_coherency_score(artist_name, lyrics)
37
+ print(f"Score: {score}")
38
+ st.write(f"Score: {score}")