niclasfw committed
Commit a088b25
1 Parent(s): c66873e

Adding code to download baseline model.

Files changed (1): app.py +4 -5
app.py CHANGED

@@ -1,5 +1,4 @@
  import streamlit as st
- import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from transformers import pipeline

@@ -26,7 +25,7 @@ model_id = "niclasfw/schlager-bot-004"
  model = AutoModelForCausalLM.from_pretrained(model_id)
  tokenizer = AutoTokenizer.from_pretrained(model_id)

- generator = pipeline(task="text-generation", model=model_id, tokenizer=model_id)
+ # generator = pipeline(task="text-generation", model=model_id, tokenizer=model_id)

  st.title('Schlager Bot')
  user_input = st.text_area('Enter verse (minimum of 15 words): ')

@@ -42,13 +41,13 @@ if user_input and button:

      ### Response:
      """
-     output = generator(prompt, do_sample=True, max_new_tokens=500, top_p=0.75, temperature=0.95, top_k=15)
+     # output = generator(prompt, do_sample=True, max_new_tokens=500, top_p=0.75, temperature=0.95, top_k=15)
      # st.write("Prompt: ", user_input)
      # input = tokenizer(prompt, padding=True, return_tensors="pt")
      # generate_ids = model.generate(input.input_ids, max_length=500, top_p=0.75, temperature=0.95, top_k=15)
      # output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-     # input_ids = tokenizer(prompt, return_tensors="pt", truncation=True)
-     # outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
+     input_ids = tokenizer(prompt, return_tensors="pt", truncation=True)
+     outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)

      st.write(output)
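
Note that, as committed, the new block has two loose ends: `tokenizer(...)` returns a BatchEncoding rather than a tensor, so `generate()` should receive its `.input_ids`, and the result lands in `outputs` while `st.write(output)` still reads the old pipeline-era variable, with no decode step in between. The following is a minimal sketch of how the generation block could run end to end, keeping the sampling parameters from the diff; the attention mask and the `batch_decode` call are assumptions, not part of the commit:

import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "niclasfw/schlager-bot-004"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "..."  # built from the Streamlit user input, as in app.py

# tokenizer(...) returns a BatchEncoding; pass its tensors to generate().
inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,  # assumption: not passed in the commit
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=500,
    do_sample=True,
    top_p=0.75,
    temperature=0.95,
    top_k=15,
)
# Decode the generated ids back to text before displaying them (the commit
# writes `output` without defining it; this decode step is an assumption).
output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
st.write(output)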