tjl223 commited on
Commit
938825c
1 Parent(s): 581e78a

fixed bug in lyric generator parsing

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. LyricGeneratorModel.py +4 -1
  3. app.py +6 -4
.gitignore CHANGED
@@ -1 +1,2 @@
1
- __pycache__/
 
 
1
+ __pycache__/
2
+ test.txt
LyricGeneratorModel.py CHANGED
@@ -26,4 +26,7 @@ class LyricGeneratorModel:
26
 
27
  output_text = self.tokenizer.batch_decode(output_tokens)[0]
28
 
29
- return output_text.split(" ->: ")[1]
 
 
 
 
26
 
27
  output_text = self.tokenizer.batch_decode(output_tokens)[0]
28
 
29
+ if "->:" in output_text:
30
+ return output_text.split("->:")[1].strip()
31
+ else:
32
+ return output_text
app.py CHANGED
@@ -14,7 +14,7 @@ def get_artists():
14
  @st.cache_resource
15
  def get_evaluator_model():
16
  lyric_evaluator_model = None
17
- with st.spinner("Loading Evaluation Model"):
18
  lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(
19
  "tjl223/artist-coherency-ensemble"
20
  )
@@ -25,7 +25,7 @@ def get_evaluator_model():
25
  @st.cache_resource
26
  def get_generator_model():
27
  lyric_generator_model = None
28
- with st.spinner("Loading Generator Model"):
29
  lyric_generator_model = LyricGeneratorModel(
30
  "tjl223/testllama2-qlora-lyric-generator-with-description"
31
  )
@@ -45,8 +45,10 @@ song_description = st.text_area("Song Description")
45
  if st.button("Submit"):
46
  prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_for_generator}\n[Song Description] {song_description}"
47
  print(f"Prompt: {prompt}")
48
- st.write(prompt)
49
- lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
 
 
50
  print(f"Lyrics: {lyrics}")
51
  for line in lyrics.split("\n"):
52
  if line.startswith("["):
 
14
  @st.cache_resource
15
  def get_evaluator_model():
16
  lyric_evaluator_model = None
17
+ with st.spinner("Loading Evaluation Model..."):
18
  lyric_evaluator_model = ArtistCoherencyModel.from_pretrained(
19
  "tjl223/artist-coherency-ensemble"
20
  )
 
25
  @st.cache_resource
26
  def get_generator_model():
27
  lyric_generator_model = None
28
+ with st.spinner("Loading Generator Model..."):
29
  lyric_generator_model = LyricGeneratorModel(
30
  "tjl223/testllama2-qlora-lyric-generator-with-description"
31
  )
 
45
  if st.button("Submit"):
46
  prompt = f"[Song Title] {song_title}\n[Song Artist] {artist_name_for_generator}\n[Song Description] {song_description}"
47
  print(f"Prompt: {prompt}")
48
+ lyrics = ""
49
+ with st.spinner("Generating Lyrics..."):
50
+ lyrics = lyric_generator_model.generate_lyrics(prompt, 1000)
51
+ st.success("Finished Generating Lyrics")
52
  print(f"Lyrics: {lyrics}")
53
  for line in lyrics.split("\n"):
54
  if line.startswith("["):