Cropinky commited on
Commit
b4d36a1
1 Parent(s): 45dc73c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -3,20 +3,20 @@ from streamlit.elements.altair import generate_chart
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from transformers import pipeline
5
  st.title("Rap Lyrics Generator")
 
 
6
 
7
  model_ckpt = "flax-community/gpt2-rap-lyric-generator"
8
  tokenizer = AutoTokenizer.from_pretrained(model_ckpt,from_flax=True)
9
  model = AutoModelForCausalLM.from_pretrained(model_ckpt,from_flax=True)
10
  text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
11
 
12
- artist = st.text_input("Enter the artist", "Eminem")
13
- song_name = st.text_input("Enter the desired song name", "Gas is going")
14
-
15
  if st.button("Generate lyrics"):
16
  st.title(f"{artist}: {song_name}")
17
  prefix_text = f"<BOS>{song_name} [Verse 1:{artist}]"
18
  generated_song = text_generation(prefix_text, max_length=500, do_sample=True)[0]
19
- for count, line in enumerate(generated_song['generated_text'].split("\n")):
 
20
  if count == 0:
21
  st.write(line[line.find('['):])
22
  continue
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from transformers import pipeline
5
  st.title("Rap Lyrics Generator")
6
+ artist = st.text_input("Enter the artist", "Eminem")
7
+ song_name = st.text_input("Enter the desired song name", "Gas is going")
8
 
9
  model_ckpt = "flax-community/gpt2-rap-lyric-generator"
10
  tokenizer = AutoTokenizer.from_pretrained(model_ckpt,from_flax=True)
11
  model = AutoModelForCausalLM.from_pretrained(model_ckpt,from_flax=True)
12
  text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
13
 
 
 
 
14
  if st.button("Generate lyrics"):
15
  st.title(f"{artist}: {song_name}")
16
  prefix_text = f"<BOS>{song_name} [Verse 1:{artist}]"
17
  generated_song = text_generation(prefix_text, max_length=500, do_sample=True)[0]
18
+ for count, line in enumerate(generated_song['generated_text'].split("\
19
+ ")):
20
  if count == 0:
21
  st.write(line[line.find('['):])
22
  continue