niclasfw commited on
Commit
3022295
1 Parent(s): c6a10d8

Making app CPU compatible.

Browse files
Files changed (2) hide show
  1. app.py +8 -4
  2. requirements.txt +6 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
 
5
  @st.cache(allow_output_mutation=True)
6
  def get_model():
7
  # load base LLM model and tokenizer
@@ -12,7 +13,7 @@ def get_model():
12
  model_id,
13
  low_cpu_mem_usage=True,
14
  torch_dtype=torch.float16,
15
- load_in_4bit=True,
16
  )
17
 
18
  return tokenizer, model
@@ -33,10 +34,13 @@ if user_input and button:
33
 
34
  ### Response:
35
  """
36
- input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids
37
- outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
38
  st.write("Prompt: ", user_input)
 
 
 
 
 
39
  st.write("**************")
40
- st.write(outputs)
41
 
42
 
 
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+
6
  @st.cache(allow_output_mutation=True)
7
  def get_model():
8
  # load base LLM model and tokenizer
 
13
  model_id,
14
  low_cpu_mem_usage=True,
15
  torch_dtype=torch.float16,
16
+ # load_in_4bit=True,
17
  )
18
 
19
  return tokenizer, model
 
34
 
35
  ### Response:
36
  """
 
 
37
  st.write("Prompt: ", user_input)
38
+ input = tokenizer([prompt], padding=True, truncation=True, return_tensors="pt")
39
+ output = model(**input)
40
+ # input_ids = tokenizer(prompt, return_tensors="pt", truncation=True)
41
+ # outputs = model.generate(input_ids=input_ids, pad_token_id=tokenizer.eos_token_id, max_new_tokens=500, do_sample=True, top_p=0.75, temperature=0.95, top_k=15)
42
+
43
  st.write("**************")
44
+ st.write(output)
45
 
46
 
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
  streamlit
2
  torch
3
- transformers
4
- accelerate
5
- bitsandbytes
 
 
 
 
1
  streamlit
2
  torch
3
+ transformers==4.34.0
4
+ accelerate==0.23.0
5
+ bitsandbytes==0.41.1
6
+ trl==0.4.7
7
+ safetensors>=0.3.1
8
+ peft==0.5.0