masanorihirano committed
Commit • dc15b84
Parent(s): b3d63a6
update
app.py CHANGED
@@ -23,11 +23,14 @@ print("starting server ...")
 
 BASE_MODEL = "decapoda-research/llama-13b-hf"
 LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"
-DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
+DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
 
 repo = None
 LOCAL_DIR = "/home/user/data/"
+PROMPT_LANG = "en"
+assert PROMPT_LANG in ["ja", "en"]
+
 if HF_TOKEN and DATASET_REPOSITORY:
     try:
         shutil.rmtree(LOCAL_DIR)
@@ -42,7 +45,6 @@ if HF_TOKEN and DATASET_REPOSITORY:
     )
     repo.git_pull()
 
-
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
 if torch.cuda.is_available():
@@ -62,7 +64,7 @@ if device == "cuda":
         load_in_8bit=True,
         device_map="auto",
     )
-    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True)
+    model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True,)
 elif device == "mps":
     model = AutoModelForCausalLM.from_pretrained(
         BASE_MODEL,
@@ -77,10 +79,7 @@ elif device == "mps":
     )
 else:
     model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        device_map={"": device},
-        low_cpu_mem_usage=True,
-        load_in_8bit=True,
+        BASE_MODEL, device_map={"": device},load_in_8bit=True, low_cpu_mem_usage=True
     )
     model = PeftModel.from_pretrained(
         model,
@@ -91,18 +90,29 @@ else:
 
 
 def generate_prompt(instruction: str, input: Optional[str] = None):
+    print(f"input: {input}")
     if input:
-        return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+        if PROMPT_LANG == "ja":
+            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 入力:\n{input}\n\n### Response:\n"
+        elif PROMPT_LANG == "en":
+            return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
 ### Instruction:
 {instruction}
 ### Input:
 {input}
 ### Response:"""
+        else:
+            raise ValueError("PROMPT_LANG")
     else:
-        return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
+        if PROMPT_LANG == "ja":
+            return f"以下はタスクを説明する指示とさらなる文脈を適用する入力の組み合わせです。\n\n### 指示:\n{instruction}\n\n### 返答:\n"
+        elif PROMPT_LANG == "en":
+            return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
 ### Instruction:
 {instruction}
 ### Response:"""
+        else:
+            raise ValueError("PROMPT_LANG")
 
 
 if device != "cpu":
@@ -114,7 +124,7 @@ if torch.__version__ >= "2":
 
 def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs):
     current_hour = now.strftime("%Y-%m-%d_%H")
-    file_name = f"prompts_{current_hour}.jsonl"
+    file_name = f"prompts_{LORA_WEIGHTS.split('/')[-1]}{current_hour}.jsonl"
 
     if repo is not None:
         repo.git_pull(rebase=True)
@@ -138,11 +148,11 @@ def evaluate(
     instruction,
     input=None,
     temperature=0.7,
-    top_p=1.0,
-    top_k=40,
-    num_beams=4,
     max_new_tokens=256,
 ):
+    num_beams: int = 1
+    top_p: float = 1.0
+    top_k: int = 0
     prompt = generate_prompt(instruction, input)
     inputs = tokenizer(prompt, return_tensors="pt")
     input_ids = inputs["input_ids"].to(device)
@@ -151,6 +161,8 @@ def evaluate(
         top_p=top_p,
         top_k=top_k,
         num_beams=num_beams,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token=tokenizer.eos_token_id,
     )
     with torch.no_grad():
         generation_output = model.generate(
@@ -161,9 +173,14 @@ def evaluate(
             max_new_tokens=max_new_tokens,
         )
     s = generation_output.sequences[0]
-    output = tokenizer.decode(s)
-    output = output.split("### Response:")[1].strip()
-    if HF_TOKEN and DATASET_REPOSITORY:
+    output = tokenizer.decode(s, skip_special_tokens=True)
+    if prompt.endswith("Response:"):
+        output = output.split("### Response:")[1].strip()
+    elif prompt.endswith("返答:"):
+        output = output.split("### 返答:")[1].strip()
+    else:
+        raise ValueError(f"No valid prompt ends. {prompt}")
+    if HF_TOKEN:
         try:
             now = datetime.datetime.now()
             current_time = now.strftime("%Y-%m-%d %H:%M:%S")
@@ -215,10 +232,10 @@ with gr.Blocks(
                 clear_button = gr.Button("Clear").style(full_width=True)
             with gr.Column(scale=5):
                 submit_button = gr.Button("Submit").style(full_width=True)
-        outputs = gr.Textbox(lines=
+        outputs = gr.Textbox(lines=4, label="Output")
 
         # inputs, top_p, temperature, top_k, repetition_penalty
-        with gr.Accordion("Parameters", open=False):
+        with gr.Accordion("Parameters", open=True):
            temperature = gr.Slider(
                 minimum=0,
                 maximum=1.0,
@@ -227,34 +244,10 @@ with gr.Blocks(
                 interactive=True,
                 label="Temperature",
             )
-            top_p = gr.Slider(
-                minimum=0,
-                maximum=1.0,
-                value=1.0,
-                step=0.05,
-                interactive=True,
-                label="Top p",
-            )
-            top_k = gr.Slider(
-                minimum=1,
-                maximum=50,
-                value=4,
-                step=1,
-                interactive=True,
-                label="Top k",
-            )
-            num_beams = gr.Slider(
-                minimum=1,
-                maximum=50,
-                value=4,
-                step=1,
-                interactive=True,
-                label="Beams",
-            )
             max_new_tokens = gr.Slider(
                 minimum=1,
-                maximum=
-                value=
+                maximum=256,
+                value=128,
                 step=1,
                 interactive=True,
                 label="Max length",
@@ -301,17 +294,17 @@ with gr.Blocks(
     inputs.submit(no_interactive, [], [submit_button, clear_button])
     inputs.submit(
         evaluate,
-        [instruction, inputs, temperature, top_p, top_k, num_beams, max_new_tokens],
+        [instruction, inputs, temperature, max_new_tokens],
         [outputs, submit_button, clear_button],
     )
     submit_button.click(no_interactive, [], [submit_button, clear_button])
     submit_button.click(
         evaluate,
-        [instruction, inputs, temperature, top_p, top_k, num_beams, max_new_tokens],
+        [instruction, inputs, temperature, max_new_tokens],
         [outputs, submit_button, clear_button],
     )
     clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
 
 demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
-
+    server_name="0.0.0.0", server_port=7860
 )
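For context on the model-loading hunks above: the app stacks a LoRA adapter on an 8-bit base model via PEFT. Below is a condensed sketch of the CUDA path using the repo IDs from the file; AutoModelForCausalLM as the base-model class and the final model.eval() are assumptions, since the diff only shows the call's keyword arguments.

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, LlamaTokenizer

BASE_MODEL = "decapoda-research/llama-13b-hf"
LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"

tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
# Base model quantized to 8-bit and sharded across available GPUs.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    load_in_8bit=True,
    device_map="auto",
)
# LoRA adapter weights applied on top of the frozen 8-bit base.
model = PeftModel.from_pretrained(model, LORA_WEIGHTS, load_in_8bit=True)
model.eval()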
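The reworked generate_prompt switches templates on PROMPT_LANG, and evaluate later strips everything up to the template's response marker. A minimal self-contained sketch of the English path follows, with hypothetical helper names and sample strings; note that the committed Japanese templates end with a trailing \n, so as written only the English templates satisfy the endswith checks.

from typing import Optional

def build_prompt_en(instruction: str, input: Optional[str] = None) -> str:
    # English templates as committed in generate_prompt (PROMPT_LANG == "en").
    if input:
        return (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n"
            f"### Instruction:\n{instruction}\n### Input:\n{input}\n### Response:"
        )
    return (
        "Below is an instruction that describes a task. Write a response that "
        "appropriately completes the request.\n"
        f"### Instruction:\n{instruction}\n### Response:"
    )

def extract_response(decoded: str, prompt: str) -> str:
    # Mirrors the post-processing added in evaluate().
    if prompt.endswith("Response:"):
        return decoded.split("### Response:")[1].strip()
    if prompt.endswith("返答:"):
        return decoded.split("### 返答:")[1].strip()
    raise ValueError(f"No valid prompt ends. {prompt}")

prompt = build_prompt_en("Translate to Japanese.", "Good morning.")
print(extract_response(prompt + " おはようございます。", prompt))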
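The only change to save_inputs_and_outputs is the log-file name, which now embeds the adapter name ahead of the hour stamp. Here is a sketch of the resulting append-and-push pattern; append_log and the record payload are hypothetical, while the pull/push calls mirror the huggingface_hub Repository API the file already uses.

import datetime
import json
import os

LOCAL_DIR = "/home/user/data/"
LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"

def append_log(now: datetime.datetime, record: dict) -> str:
    # One JSONL file per hour, now prefixed with the adapter name.
    current_hour = now.strftime("%Y-%m-%d_%H")
    file_name = f"prompts_{LORA_WEIGHTS.split('/')[-1]}{current_hour}.jsonl"
    path = os.path.join(LOCAL_DIR, file_name)
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
    return file_name

# Sketched usage around the Repository handle created earlier in app.py:
# repo.git_pull(rebase=True)
# append_log(datetime.datetime.now(), {"instruction": "...", "output": "..."})
# repo.push_to_hub(commit_message="update logs")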
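The generation hunk pins the sampling parameters in code and adds token-id kwargs. A sketch of the equivalent GenerationConfig is below; transformers documents the stopping token as eos_token_id, so that spelling is used here, whereas the committed line passes eos_token=.

from transformers import GenerationConfig, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-13b-hf")

generation_config = GenerationConfig(
    temperature=0.7,
    top_p=1.0,   # now fixed in code instead of exposed as sliders
    top_k=0,
    num_beams=1,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,  # documented spelling; the diff passes eos_token=
)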
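With top_p, top_k, and num_beams removed from the UI, evaluate now takes exactly four inputs, and launch binds to 0.0.0.0:7860. A reduced, runnable sketch of that wiring with a stub evaluate; the temperature slider's default value and the instruction/input textbox labels are assumptions not shown in the diff.

import gradio as gr

def evaluate(instruction, inputs, temperature=0.7, max_new_tokens=256):
    # Stub standing in for the model call wired up in app.py.
    return f"[temperature={temperature}, max_new_tokens={max_new_tokens}] {instruction} / {inputs}"

with gr.Blocks() as demo:
    instruction = gr.Textbox(lines=2, label="Instruction")
    inputs = gr.Textbox(lines=2, label="Input")
    submit_button = gr.Button("Submit")
    outputs = gr.Textbox(lines=4, label="Output")
    with gr.Accordion("Parameters", open=True):
        temperature = gr.Slider(0, 1.0, value=0.7, step=0.05, label="Temperature")
        max_new_tokens = gr.Slider(1, 256, value=128, step=1, label="Max length")
    submit_button.click(
        evaluate,
        [instruction, inputs, temperature, max_new_tokens],
        [outputs],
    )

demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=7860)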