masanorihirano committed
Commit 6bde6cb
Parent: dfd9622
Files changed (2)
  1. app.py +140 -137
  2. pyproject.toml +2 -2
app.py CHANGED
@@ -4,21 +4,94 @@ import os
 import shutil
 from typing import Optional
 from typing import Tuple
+from typing import Union

 import gradio as gr
 import requests
 import torch
-from fastchat.serve.inference import compress_module
-from fastchat.serve.inference import raise_warning_for_old_weights
+from fastchat.conversation import Conversation
+from fastchat.conversation import SeparatorStyle
+from fastchat.conversation import get_conv_template
+from fastchat.conversation import register_conv_template
+from fastchat.model.model_adapter import BaseAdapter
+from fastchat.model.model_adapter import load_model
+from fastchat.model.model_adapter import model_adapters
+from fastchat.serve.cli import SimpleChatIO
+from fastchat.serve.inference import generate_stream
 from huggingface_hub import Repository
-from huggingface_hub import hf_hub_download
 from huggingface_hub import snapshot_download
 from peft import LoraConfig
+from peft import PeftModel
 from peft import get_peft_model
 from peft import set_peft_model_state_dict
 from transformers import AutoModelForCausalLM
-from transformers import GenerationConfig
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
+from transformers import PreTrainedModel
+from transformers import PreTrainedTokenizerBase
+
+
+class FastTokenizerAvailableBaseAdapter(BaseAdapter):
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        except ValueError:
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+        )
+        return model, tokenizer
+
+
+model_adapters[-1] = FastTokenizerAvailableBaseAdapter()
+
+
+def load_lora_model(
+    model_path: str,
+    lora_weight: str,
+    device: str,
+    num_gpus: int,
+    max_gpu_memory: Optional[str] = None,
+    load_8bit: bool = False,
+    cpu_offloading: bool = False,
+    debug: bool = False,
+) -> Tuple[Union[PreTrainedModel, PeftModel], PreTrainedTokenizerBase]:
+    model: Union[PreTrainedModel, PeftModel]
+    tokenizer: PreTrainedTokenizerBase
+    model, tokenizer = load_model(
+        model_path=model_path,
+        device=device,
+        num_gpus=num_gpus,
+        max_gpu_memory=max_gpu_memory,
+        load_8bit=load_8bit,
+        cpu_offloading=cpu_offloading,
+        debug=debug,
+    )
+    if lora_weight is not None:
+        # model = PeftModelForCausalLM.from_pretrained(model, model_path, **kwargs)
+        config = LoraConfig.from_pretrained(lora_weight)
+        model = get_peft_model(model, config)
+
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            lora_weight, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                lora_weight, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+        # The two files above have a different name depending on how they were saved,
+        # but are actually the same.
+        if os.path.exists(checkpoint_name):
+            adapters_weights = torch.load(checkpoint_name)
+            set_peft_model_state_dict(model, adapters_weights)
+        else:
+            raise IOError(f"Checkpoint {checkpoint_name} not found")
+
+    if debug:
+        print(model)
+
+    return model, tokenizer
+

 print(datetime.datetime.now())

@@ -29,15 +102,15 @@ print(NUM_THREADS)
 print("starting server ...")

 BASE_MODEL = "decapoda-research/llama-13b-hf"
-LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"
+LORA_WEIGHTS_HF = "izumi-lab/llama-13b-japanese-lora-v0-1ep"
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
 SLACK_WEBHOOK = os.environ.get("SLACK_WEBHOOK", None)

+LORA_WEIGHTS = snapshot_download(LORA_WEIGHTS_HF)
+
 repo = None
 LOCAL_DIR = "/home/user/data/"
-PROMPT_LANG = "en"
-assert PROMPT_LANG in ["ja", "en"]

 if HF_TOKEN and DATASET_REPOSITORY:
     try:
@@ -53,85 +126,34 @@ if HF_TOKEN and DATASET_REPOSITORY:
     )
     repo.git_pull()

-tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
-
 if torch.cuda.is_available():
     device = "cuda"
 else:
     device = "cpu"

-try:
-    if torch.backends.mps.is_available():
-        device = "mps"
-except Exception:
-    pass
-
-resume_from_checkpoint = snapshot_download(
-    repo_id=LORA_WEIGHTS, use_auth_token=HF_TOKEN
-)
-checkpoint_name = hf_hub_download(
-    repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN
+model, tokenizer = load_lora_model(
+    model_path=BASE_MODEL,
+    lora_weight=LORA_WEIGHTS,
+    device=device,
+    num_gpus=1,
+    max_gpu_memory="16GiB",
+    load_8bit=False,
+    cpu_offloading=False,
+    debug=False,
 )
-if device == "cuda":
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
-    )
-elif device == "mps":
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-    )
-else:
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        load_in_8bit=True,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.float16,
-    )

-config = LoraConfig.from_pretrained(resume_from_checkpoint)
-model = get_peft_model(model, config)
-adapters_weights = torch.load(checkpoint_name)
-set_peft_model_state_dict(model, adapters_weights)
-raise_warning_for_old_weights(BASE_MODEL, model)
-compress_module(model, device)
-# if device == "cuda" or device == "mps":
-#     model = model.to(device)
-
-
-def generate_prompt(instruction: str, input: Optional[str] = None):
-    if input:
-        if PROMPT_LANG == "ja":
-            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### Response:\n"
-        elif PROMPT_LANG == "en":
-            return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-### Instruction:
-{instruction}
-### Input:
-{input}
-### Response:"""
-        else:
-            raise ValueError("PROMPT_LANG")
-    else:
-        if PROMPT_LANG == "ja":
-            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 返答:\n"
-        elif PROMPT_LANG == "en":
-            return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
-### Instruction:
-{instruction}
-### Response:"""
-        else:
-            raise ValueError("PROMPT_LANG")
+Conversation._get_prompt = Conversation.get_prompt
+Conversation._append_message = Conversation.append_message
+

+def conversation_append_message(cls, role: str, message: str):
+    cls.offset = -2
+    return cls._append_message(role, message)

-if device != "cpu":
-    model.half()
-model.eval()
-if torch.__version__ >= "2":
-    model = torch.compile(model)
+
+def conversation_get_prompt_overrider(cls: Conversation) -> str:
+    cls.messages = cls.messages[-2:]
+    return cls._get_prompt()


 def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
@@ -158,20 +180,15 @@ def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
 # https://github.com/gradio-app/gradio/issues/3514
 def evaluate(
     instruction,
-    input=None,
     temperature=0.7,
-    max_tokens=384,
+    max_tokens=256,
     repetition_penalty=1.0,
 ):
     try:
-        if temperature < 1e-8:
-            temperature = 1e-8
-        num_beams: int = 1
-        top_p: float = 0.75
-        top_k: int = 40
-        prompt = generate_prompt(instruction, input)
-        inputs = tokenizer(prompt, return_tensors="pt")
-        if len(inputs["input_ids"][0]) > max_tokens - 10:
+        conv_template = "japanese"
+
+        inputs = tokenizer(instruction, return_tensors="pt")
+        if len(inputs["input_ids"][0]) > max_tokens - 40:
             if HF_TOKEN and DATASET_REPOSITORY:
                 try:
                     now = datetime.datetime.now()
@@ -179,13 +196,10 @@ def evaluate(
                     print(f"[{current_time}] Pushing prompt and completion to the Hub")
                     save_inputs_and_outputs(
                         now,
-                        prompt,
+                        instruction,
                         "",
                         {
                             "temperature": temperature,
-                            "top_p": top_p,
-                            "top_k": top_k,
-                            "num_beams": num_beams,
                             "max_tokens": max_tokens,
                             "repetition_penalty": repetition_penalty,
                         },
@@ -193,37 +207,34 @@ def evaluate(
                 except Exception as e:
                     print(e)
             return (
-                f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} ( > {max_tokens - 10}) tokens are used.",
+                f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} ( > {max_tokens - 40}) tokens are used.",
                 gr.update(interactive=True),
                 gr.update(interactive=True),
             )
-        input_ids = inputs["input_ids"].to(device)
-        generation_config = GenerationConfig(
-            do_sample=False,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            num_beams=num_beams,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token=tokenizer.eos_token_id,
-        )
-        with torch.no_grad():
-            generation_output = model.generate(
-                input_ids=input_ids,
-                generation_config=generation_config,
-                return_dict_in_generate=True,
-                output_scores=True,
-                max_new_tokens=max_tokens - len(input_ids),
-            )
-        s = generation_output.sequences[0]
-        output = tokenizer.decode(s, skip_special_tokens=True)
-        if prompt.endswith("Response:"):
-            output = output.split("### Response:")[1].strip()
-        elif prompt.endswith("返答:"):
-            output = output.split("### 返答:")[1].strip()
-        else:
-            raise ValueError(f"No valid prompt ends. {prompt}")
+
+        conv = get_conv_template(conv_template)
+
+        conv.append_message(conv.roles[0], instruction)
+        conv.append_message(conv.roles[1], None)
+
+        generate_stream_func = generate_stream
+        prompt = conv.get_prompt()
+
+        gen_params = {
+            "model": BASE_MODEL,
+            "prompt": prompt,
+            "temperature": temperature,
+            "max_new_tokens": max_tokens - len(inputs["input_ids"][0]) - 30,
+            "stop": conv.stop_str,
+            "stop_token_ids": conv.stop_token_ids,
+            "echo": False,
+            "repetition_penalty": repetition_penalty,
+        }
+        chatio = SimpleChatIO()
+        chatio.prompt_for_output(conv.roles[1])
+        output_stream = generate_stream_func(model, tokenizer, gen_params, device)
+        output = chatio.stream_output(output_stream)
+
         if HF_TOKEN and DATASET_REPOSITORY:
             try:
                 now = datetime.datetime.now()
@@ -235,9 +246,6 @@ def evaluate(
                     output,
                     {
                         "temperature": temperature,
-                        "top_p": top_p,
-                        "top_k": top_k,
-                        "num_beams": num_beams,
                         "max_tokens": max_tokens,
                         "repetition_penalty": repetition_penalty,
                     },
@@ -258,6 +266,7 @@ def evaluate(
             "username": "Hugging Face Space",
             "channel": "#monitor",
         }
+
        try:
            requests.post(SLACK_WEBHOOK, data=json.dumps(payload_dic))
        except Exception:
@@ -371,25 +380,19 @@ with gr.Blocks(
            visible=True
        )

-    accept_button.click(
-        fn=enable_inputs,
-        inputs=[],
-        outputs=[user_consent_block, main_block],
-        queue=False,
-    )
-    inputs.submit(no_interactive, [], [submit_button, clear_button])
-    inputs.submit(
-        evaluate,
-        [instruction, inputs, temperature, max_tokens, repetition_penalty],
-        [outputs, submit_button, clear_button],
-    )
+    accept_button.click(
+        fn=enable_inputs,
+        inputs=[],
+        outputs=[user_consent_block, main_block],
+        queue=False,
+    )
     submit_button.click(no_interactive, [], [submit_button, clear_button])
     submit_button.click(
         evaluate,
-        [instruction, inputs, temperature, max_tokens, repetition_penalty],
+        [instruction, temperature, max_tokens, repetition_penalty],
         [outputs, submit_button, clear_button],
     )
-    clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
+    clear_button.click(reset_textbox, [], [instruction, outputs], queue=False)

 demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
     server_name="0.0.0.0", server_port=7860
pyproject.toml CHANGED
@@ -15,8 +15,8 @@ huggingface-hub = "^0.14.1"
 sentencepiece = "^0.1.99"
 bitsandbytes = "^0.38.1"
 accelerate = "^0.19.0"
-fschat = "^0.2.3"
-transformers = "^4.29.2"
+fschat = "0.2.8"
+transformers = "4.28.1"


 [tool.poetry.group.dev.dependencies]
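
Reviewer note: the caret ranges become exact pins, presumably because the refactored app.py relies on fastchat entry points and call signatures that differ across fschat releases (fastchat.model.model_adapter, the one-argument SimpleChatIO.stream_output). If the versions should also be enforced at runtime, a small guard like the hypothetical check below could be added at startup; it is a sketch and not part of this commit.

    # Reviewer sketch (not in this commit): fail fast if the Space image
    # resolves versions other than the exact pins in pyproject.toml.
    from importlib.metadata import version

    EXPECTED = {"fschat": "0.2.8", "transformers": "4.28.1"}

    for package, expected in EXPECTED.items():
        installed = version(package)
        if installed != expected:
            raise RuntimeError(f"{package}=={installed} installed, expected =={expected}")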