Commit 6bde6cb (WIP)
1 Parent(s): dfd9622
Files changed:
- app.py +140 -137
- pyproject.toml +2 -2
app.py CHANGED
@@ -4,21 +4,94 @@ import os
 import shutil
 from typing import Optional
 from typing import Tuple
+from typing import Union
 
 import gradio as gr
 import requests
 import torch
-from fastchat.
-from fastchat.
+from fastchat.conversation import Conversation
+from fastchat.conversation import SeparatorStyle
+from fastchat.conversation import get_conv_template
+from fastchat.conversation import register_conv_template
+from fastchat.model.model_adapter import BaseAdapter
+from fastchat.model.model_adapter import load_model
+from fastchat.model.model_adapter import model_adapters
+from fastchat.serve.cli import SimpleChatIO
+from fastchat.serve.inference import generate_stream
 from huggingface_hub import Repository
-from huggingface_hub import hf_hub_download
 from huggingface_hub import snapshot_download
 from peft import LoraConfig
+from peft import PeftModel
 from peft import get_peft_model
 from peft import set_peft_model_state_dict
 from transformers import AutoModelForCausalLM
-from transformers import
-from transformers import
+from transformers import AutoTokenizer
+from transformers import PreTrainedModel
+from transformers import PreTrainedTokenizerBase
+
+
+class FastTokenizerAvailableBaseAdapter(BaseAdapter):
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+        except ValueError:
+            tokenizer = AutoTokenizer.from_pretrained(model_path)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path, low_cpu_mem_usage=True, **from_pretrained_kwargs
+        )
+        return model, tokenizer
+
+
+model_adapters[-1] = FastTokenizerAvailableBaseAdapter()
+
+
+def load_lora_model(
+    model_path: str,
+    lora_weight: str,
+    device: str,
+    num_gpus: int,
+    max_gpu_memory: Optional[str] = None,
+    load_8bit: bool = False,
+    cpu_offloading: bool = False,
+    debug: bool = False,
+) -> Tuple[Union[PreTrainedModel, PeftModel], PreTrainedTokenizerBase]:
+    model: Union[PreTrainedModel, PeftModel]
+    tokenizer: PreTrainedTokenizerBase
+    model, tokenizer = load_model(
+        model_path=model_path,
+        device=device,
+        num_gpus=num_gpus,
+        max_gpu_memory=max_gpu_memory,
+        load_8bit=load_8bit,
+        cpu_offloading=cpu_offloading,
+        debug=debug,
+    )
+    if lora_weight is not None:
+        # model = PeftModelForCausalLM.from_pretrained(model, model_path, **kwargs)
+        config = LoraConfig.from_pretrained(lora_weight)
+        model = get_peft_model(model, config)
+
+        # Check the available weights and load them
+        checkpoint_name = os.path.join(
+            lora_weight, "pytorch_model.bin"
+        )  # Full checkpoint
+        if not os.path.exists(checkpoint_name):
+            checkpoint_name = os.path.join(
+                lora_weight, "adapter_model.bin"
+            )  # only LoRA model - LoRA config above has to fit
+        # The two files above have a different name depending on how they were saved,
+        # but are actually the same.
+        if os.path.exists(checkpoint_name):
+            adapters_weights = torch.load(checkpoint_name)
+            set_peft_model_state_dict(model, adapters_weights)
+        else:
+            raise IOError(f"Checkpoint {checkpoint_name} not found")
+
+    if debug:
+        print(model)
+
+    return model, tokenizer
+
 
 print(datetime.datetime.now())
 
@@ -29,15 +102,15 @@ print(NUM_THREADS)
 print("starting server ...")
 
 BASE_MODEL = "decapoda-research/llama-13b-hf"
-
+LORA_WEIGHTS_HF = "izumi-lab/llama-13b-japanese-lora-v0-1ep"
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
 SLACK_WEBHOOK = os.environ.get("SLACK_WEBHOOK", None)
 
+LORA_WEIGHTS = snapshot_download(LORA_WEIGHTS_HF)
+
 repo = None
 LOCAL_DIR = "/home/user/data/"
-PROMPT_LANG = "en"
-assert PROMPT_LANG in ["ja", "en"]
 
 if HF_TOKEN and DATASET_REPOSITORY:
     try:
@@ -53,85 +126,34 @@ if HF_TOKEN and DATASET_REPOSITORY:
         )
        repo.git_pull()
 
-tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
-
 if torch.cuda.is_available():
     device = "cuda"
 else:
     device = "cpu"
 
-
-
-
-
-
-
-
-
-
-checkpoint_name = hf_hub_download(
-    repo_id=LORA_WEIGHTS, filename="adapter_model.bin", use_auth_token=HF_TOKEN
+model, tokenizer = load_lora_model(
+    model_path=BASE_MODEL,
+    lora_weight=LORA_WEIGHTS,
+    device=device,
+    num_gpus=1,
+    max_gpu_memory="16GiB",
+    load_8bit=False,
+    cpu_offloading=False,
+    debug=False,
 )
-if device == "cuda":
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL, load_in_8bit=True, device_map="auto", torch_dtype=torch.float16
-    )
-elif device == "mps":
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        load_in_8bit=True,
-        torch_dtype=torch.float16,
-    )
-else:
-    model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        load_in_8bit=True,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.float16,
-    )
 
-
-
-
-set_peft_model_state_dict(model, adapters_weights)
-raise_warning_for_old_weights(BASE_MODEL, model)
-compress_module(model, device)
-# if device == "cuda" or device == "mps":
-#     model = model.to(device)
-
-
-def generate_prompt(instruction: str, input: Optional[str] = None):
-    if input:
-        if PROMPT_LANG == "ja":
-            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### Response:\n"
-        elif PROMPT_LANG == "en":
-            return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
-### Instruction:
-{instruction}
-### Input:
-{input}
-### Response:"""
-        else:
-            raise ValueError("PROMPT_LANG")
-    else:
-        if PROMPT_LANG == "ja":
-            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 返答:\n"
-        elif PROMPT_LANG == "en":
-            return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
-### Instruction:
-{instruction}
-### Response:"""
-        else:
-            raise ValueError("PROMPT_LANG")
+Conversation._get_prompt = Conversation.get_prompt
+Conversation._append_message = Conversation.append_message
+
 
+def conversation_append_message(cls, role: str, message: str):
+    cls.offset = -2
+    return cls._append_message(role, message)
 
-
-
-
-
-model = torch.compile(model)
+
+def conversation_get_prompt_overrider(cls: Conversation) -> str:
+    cls.messages = cls.messages[-2:]
+    return cls._get_prompt()
 
 
 def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
@@ -158,20 +180,15 @@ def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
 # https://github.com/gradio-app/gradio/issues/3514
 def evaluate(
     instruction,
-    input=None,
     temperature=0.7,
-    max_tokens=
+    max_tokens=256,
     repetition_penalty=1.0,
 ):
     try:
-
-
-
-
-        top_k: int = 40
-        prompt = generate_prompt(instruction, input)
-        inputs = tokenizer(prompt, return_tensors="pt")
-        if len(inputs["input_ids"][0]) > max_tokens - 10:
+        conv_template = "japanese"
+
+        inputs = tokenizer(instruction, return_tensors="pt")
+        if len(inputs["input_ids"][0]) > max_tokens - 40:
             if HF_TOKEN and DATASET_REPOSITORY:
                 try:
                     now = datetime.datetime.now()
@@ -179,13 +196,10 @@ def evaluate(
                     print(f"[{current_time}] Pushing prompt and completion to the Hub")
                     save_inputs_and_outputs(
                         now,
-
+                        instruction,
                        "",
                        {
                            "temperature": temperature,
-                            "top_p": top_p,
-                            "top_k": top_k,
-                            "num_beams": num_beams,
                            "max_tokens": max_tokens,
                            "repetition_penalty": repetition_penalty,
                        },
@@ -193,37 +207,34 @@ def evaluate(
                 except Exception as e:
                     print(e)
             return (
-                f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} ( > {max_tokens -
+                f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} ( > {max_tokens - 40}) tokens are used.",
                 gr.update(interactive=True),
                 gr.update(interactive=True),
             )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            output = output.split("### 返答:")[1].strip()
-        else:
-            raise ValueError(f"No valid prompt ends. {prompt}")
+
+        conv = get_conv_template(conv_template)
+
+        conv.append_message(conv.roles[0], instruction)
+        conv.append_message(conv.roles[1], None)
+
+        generate_stream_func = generate_stream
+        prompt = conv.get_prompt()
+
+        gen_params = {
+            "model": BASE_MODEL,
+            "prompt": prompt,
+            "temperature": temperature,
+            "max_new_tokens": max_tokens - len(inputs["input_ids"][0]) - 30,
+            "stop": conv.stop_str,
+            "stop_token_ids": conv.stop_token_ids,
+            "echo": False,
+            "repetition_penalty": repetition_penalty,
+        }
+        chatio = SimpleChatIO()
+        chatio.prompt_for_output(conv.roles[1])
+        output_stream = generate_stream_func(model, tokenizer, gen_params, device)
+        output = chatio.stream_output(output_stream)
+
         if HF_TOKEN and DATASET_REPOSITORY:
             try:
                 now = datetime.datetime.now()
@@ -235,9 +246,6 @@ def evaluate(
                     output,
                    {
                        "temperature": temperature,
-                        "top_p": top_p,
-                        "top_k": top_k,
-                        "num_beams": num_beams,
                        "max_tokens": max_tokens,
                        "repetition_penalty": repetition_penalty,
                    },
@@ -258,6 +266,7 @@ def evaluate(
         "username": "Hugging Face Space",
         "channel": "#monitor",
     }
+
     try:
         requests.post(SLACK_WEBHOOK, data=json.dumps(payload_dic))
     except Exception:
@@ -371,25 +380,19 @@ with gr.Blocks(
                 visible=True
             )
 
-
-
-
-
-
-
-    inputs.submit(no_interactive, [], [submit_button, clear_button])
-    inputs.submit(
-        evaluate,
-        [instruction, inputs, temperature, max_tokens, repetition_penalty],
-        [outputs, submit_button, clear_button],
-    )
+    accept_button.click(
+        fn=enable_inputs,
+        inputs=[],
+        outputs=[user_consent_block, main_block],
+        queue=False,
+    )
     submit_button.click(no_interactive, [], [submit_button, clear_button])
     submit_button.click(
        evaluate,
-        [instruction,
+        [instruction, temperature, max_tokens, repetition_penalty],
        [outputs, submit_button, clear_button],
    )
-    clear_button.click(reset_textbox, [], [instruction,
+    clear_button.click(reset_textbox, [], [instruction, outputs], queue=False)
 
 demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
     server_name="0.0.0.0", server_port=7860
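For orientation, here is a minimal sketch (not part of the commit) of how the pieces this diff introduces fit together end to end, assuming the fschat 0.2.8 / transformers 4.28.1 pins in pyproject.toml below, that a "japanese" conversation template is registered elsewhere in app.py via the imported register_conv_template, and that the LoRA adapter is attached with PeftModel.from_pretrained rather than the manual get_peft_model + set_peft_model_state_dict path used in load_lora_model. The helper name run_once and the sample generation parameters are illustrative only.

import torch
from fastchat.conversation import get_conv_template
from fastchat.model.model_adapter import load_model
from fastchat.serve.cli import SimpleChatIO
from fastchat.serve.inference import generate_stream
from peft import PeftModel


def run_once(instruction: str) -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Same call load_lora_model() makes in the diff: fastchat selects a model
    # adapter and returns the base model plus tokenizer.
    model, tokenizer = load_model(
        model_path="decapoda-research/llama-13b-hf",
        device=device,
        num_gpus=1,
        max_gpu_memory="16GiB",
        load_8bit=False,
        cpu_offloading=False,
        debug=False,
    )
    # The diff loads adapter_model.bin by hand; this is the equivalent
    # high-level PEFT route (assumption, not the commit's exact code).
    model = PeftModel.from_pretrained(model, "izumi-lab/llama-13b-japanese-lora-v0-1ep")

    # Build the prompt from a conversation template, as the new evaluate() does.
    # Assumes the "japanese" template was registered via register_conv_template.
    conv = get_conv_template("japanese")
    conv.append_message(conv.roles[0], instruction)
    conv.append_message(conv.roles[1], None)

    gen_params = {
        "model": "decapoda-research/llama-13b-hf",
        "prompt": conv.get_prompt(),
        "temperature": 0.7,
        "max_new_tokens": 128,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
        "repetition_penalty": 1.0,
    }
    chatio = SimpleChatIO()
    chatio.prompt_for_output(conv.roles[1])
    output_stream = generate_stream(model, tokenizer, gen_params, device)
    return chatio.stream_output(output_stream)

The conversation_append_message and conversation_get_prompt_overrider helpers defined in the diff appear intended, once patched onto Conversation, to keep only the latest instruction/response pair so each request is effectively single-turn.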
pyproject.toml CHANGED
@@ -15,8 +15,8 @@ huggingface-hub = "^0.14.1"
 sentencepiece = "^0.1.99"
 bitsandbytes = "^0.38.1"
 accelerate = "^0.19.0"
-fschat = "
-transformers = "
+fschat = "0.2.8"
+transformers = "4.28.1"
 
 
 [tool.poetry.group.dev.dependencies]