arieridwans commited on
Commit
07aebcc
1 Parent(s): 24f2e60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -12
app.py CHANGED
@@ -15,7 +15,7 @@ model = AutoModelForCausalLM.from_pretrained(
15
  )
16
 
17
  # Streamlit UI
18
- st.title("Eleanor Rigby - Lyrics Generation")
19
 
20
  # User input prompt
21
  user_prompt = st.text_area("Enter your prompt that can be song lyrics:", """Yesterday, I saw you in my dream""")
@@ -24,16 +24,12 @@ user_prompt = st.text_area("Enter your prompt that can be song lyrics:", """Yest
24
  if st.button("Generate Output"):
25
  instruct_prompt = "Instruct:You are a song writer and your main reference is The Beatles. Write a song lyrics by completing these words:"
26
  output_prompt = "Output:"
27
- prompt = """ {0}{1}\n{2} """.format(instruct_prompt, user_prompt, output_prompt)
28
- with torch.no_grad():
29
- token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
30
- output_ids = model.generate(
31
- token_ids.to(model.device),
32
- max_new_tokens=512,
33
- do_sample=True,
34
- temperature=0.3
35
- )
36
-
37
- output = tokenizer.decode(output_ids[0][token_ids.size(1):])
38
  st.text("Generated Result:")
39
  st.write(output)
 
15
  )
16
 
17
  # Streamlit UI
18
+ st.title("Eleanor Rigby")
19
 
20
  # User input prompt
21
  user_prompt = st.text_area("Enter your prompt that can be song lyrics:", """Yesterday, I saw you in my dream""")
 
24
  if st.button("Generate Output"):
25
  instruct_prompt = "Instruct:You are a song writer and your main reference is The Beatles. Write a song lyrics by completing these words:"
26
  output_prompt = "Output:"
27
+ input = inference_tokenizer(""" {0}{1}\n{2} """.format(instruct_prompt, user_prompt, output_prompt),
28
+ return_tensors="pt",
29
+ return_attention_mask=False,
30
+ padding=True,
31
+ truncation=True)
32
+ result = inference_model.generate(**input, repetition_penalty=1.2, max_length=1024)
33
+ output = inference_tokenizer.batch_decode(result, skip_special_tokens=True)[0]
 
 
 
 
34
  st.text("Generated Result:")
35
  st.write(output)