niclasfw commited on
Commit
73004ed
1 Parent(s): 46c69ae

Updating app.py to allow for CPU only use.

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -1,26 +1,31 @@
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
 
5
 
6
- @st.cache(allow_output_mutation=True)
7
- def get_model():
8
- # load base LLM model and tokenizer
9
 
10
- model_id = "niclasfw/schlager-bot-004"
11
- tokenizer = AutoTokenizer.from_pretrained(model_id)
12
- model = AutoModelForCausalLM.from_pretrained(
13
- model_id,
14
- # low_cpu_mem_usage=True,
15
- # torch_dtype=torch.float16,
16
- # load_in_4bit=True,
17
- )
18
 
19
- return tokenizer, model
20
 
21
- tokenizer, model = get_model()
22
 
23
- # st.title('Schlager Bot')
 
 
 
 
24
  user_input = st.text_area('Enter verse (minimum of 15 words): ')
25
  button = st.button('Generate Lyrics')
26
 
@@ -34,14 +39,14 @@ if user_input and button:
34
 
35
  ### Response:
36
  """
37
- st.write("Prompt: ", user_input)
38
- input = tokenizer(prompt, padding=True, return_tensors="pt")
39
- generate_ids = model.generate(input.input_ids, max_length=500, top_p=0.75, temperature=0.95, top_k=15)
40
- output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
 
41
  # input_ids = tokenizer(prompt, return_tensors="pt", truncation=True)
42
  # outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
43
 
44
- st.write("**************")
45
  st.write(output)
46
 
47
 
 
1
  import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from transformers import pipeline
5
 
6
 
7
+ # @st.cache(allow_output_mutation=True)
8
+ # def get_model():
9
+ # # load base LLM model and tokenizer
10
 
11
+ # model_id = "niclasfw/schlager-bot-004"
12
+ # tokenizer = AutoTokenizer.from_pretrained(model_id)
13
+ # model = AutoModelForCausalLM.from_pretrained(
14
+ # model_id,
15
+ # # low_cpu_mem_usage=True,
16
+ # # torch_dtype=torch.float16,
17
+ # # load_in_4bit=True,
18
+ # )
19
 
20
+ # return tokenizer, model
21
 
22
+ # tokenizer, model = get_model()
23
 
24
+ model_id = "niclasfw/schlager-bot-004"
25
+
26
+ generator = pipeline(task="text-generation", model=model_id, tokenizer=model_id)
27
+
28
+ st.title('Schlager Bot')
29
  user_input = st.text_area('Enter verse (minimum of 15 words): ')
30
  button = st.button('Generate Lyrics')
31
 
 
39
 
40
  ### Response:
41
  """
42
+ output = generator(prompt, do_sample=True, max_new_tokens=500, top_p=0.75, temperature=0.95, top_k=15)
43
+ # st.write("Prompt: ", user_input)
44
+ # input = tokenizer(prompt, padding=True, return_tensors="pt")
45
+ # generate_ids = model.generate(input.input_ids, max_length=500, top_p=0.75, temperature=0.95, top_k=15)
46
+ # output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
47
  # input_ids = tokenizer(prompt, return_tensors="pt", truncation=True)
48
  # outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
49
 
 
50
  st.write(output)
51
 
52