Fta98 committed on
Commit
b3f598f
·
1 Parent(s): 20734fc
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -5,15 +5,22 @@ from transformers import AutoModelForCausalLM, LlamaTokenizer
5
  @st.cache_resource
6
  def load():
7
  """
8
- model = AutoModelForCausalLM.from_pretrained(
9
- "stabilityai/japanese-stablelm-instruct-alpha-7b",
10
- trust_remote_code=True,
11
- )
 
 
 
 
 
 
 
 
 
12
  """
13
- model = None
14
  tokenizer = LlamaTokenizer.from_pretrained(
15
- "novelai/nerdstash-tokenizer-v1",
16
- additional_special_tokens=['▁▁'],
17
  )
18
  return model, tokenizer
19
 
@@ -23,7 +30,7 @@ def get_prompt(user_query, system_prompt, messages="", sep="\n\n### "):
23
  msgs = [": \n" + user_query, ": "]
24
  if messages:
25
  roles.insert(1, "入力")
26
- msgs.insert(1, ": \n" + "\n".join(message for _, message in messages))
27
 
28
  for role, msg in zip(roles, msgs):
29
  prompt += sep + role + msg
 
5
@st.cache_resource
def load():
    """Load the 8-bit base model, apply the LoRA adapter, and load the tokenizer.

    Returns:
        tuple: (model, tokenizer) — the PEFT-wrapped causal LM and its
        LlamaTokenizer, both cached across Streamlit reruns via
        ``st.cache_resource``.
    """
    # BUG FIX: in the committed version this loading code sat between the
    # leftover triple-quote delimiters of the previously commented-out block,
    # so it was a string literal and never ran — `model` was unbound and the
    # `return` raised NameError. The delimiters are removed so it executes.
    base_model = AutoModelForCausalLM.from_pretrained(
        "stabilityai/japanese-stablelm-instruct-alpha-7b",
        device_map="auto",
        low_cpu_mem_usage=True,
        variant="int8",
        load_in_8bit=True,
        trust_remote_code=True,
    )
    # NOTE(review): PeftModel must be imported at the top of the file
    # (`from peft import PeftModel`) — the visible diff hunk only shows the
    # transformers imports; verify the import exists.
    model = PeftModel.from_pretrained(
        base_model,
        "lora_adapter",
        device_map="auto",
    )
    # Tokenizer is loaded from the adapter repo so its vocab/special tokens
    # match what the adapter was trained with.
    tokenizer = LlamaTokenizer.from_pretrained(
        "lora_adapter",
    )
    return model, tokenizer
26
 
 
30
  msgs = [": \n" + user_query, ": "]
31
  if messages:
32
  roles.insert(1, "入力")
33
+ msgs.insert(1, ": \n" + "\n".join(message["content"] for message in messages))
34
 
35
  for role, msg in zip(roles, msgs):
36
  prompt += sep + role + msg