masanorihirano committed
Commit c040907
1 Parent(s): 737268c

update script

Files changed (2):
  1. .gitignore +2 -0
  2. app.py +80 -59
.gitignore CHANGED
@@ -1,3 +1,5 @@
+secret.txt
+slack_url.txt
 .idea
 .env
 poetry.lock
app.py CHANGED
@@ -6,6 +6,7 @@ from typing import Optional
 from typing import Tuple
 
 import gradio as gr
+import requests
 import torch
 from fastchat.serve.inference import compress_module
 from fastchat.serve.inference import raise_warning_for_old_weights
@@ -31,6 +32,7 @@ BASE_MODEL = "decapoda-research/llama-13b-hf"
 LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 DATASET_REPOSITORY = os.environ.get("DATASET_REPOSITORY", None)
+SLACK_WEBHOOK = os.environ.get("SLACK_WEBHOOK", None)
 
 repo = None
 LOCAL_DIR = "/home/user/data/"
@@ -161,12 +163,65 @@ def evaluate(
     max_tokens=384,
     repetition_penalty=1.0,
 ):
-    num_beams: int = 1
-    top_p: float = 0.75
-    top_k: int = 40
-    prompt = generate_prompt(instruction, input)
-    inputs = tokenizer(prompt, return_tensors="pt")
-    if len(inputs["input_ids"][0]) > max_tokens + 10:
+    try:
+        num_beams: int = 1
+        top_p: float = 0.75
+        top_k: int = 40
+        prompt = generate_prompt(instruction, input)
+        inputs = tokenizer(prompt, return_tensors="pt")
+        if len(inputs["input_ids"][0]) > max_tokens + 10:
+            if HF_TOKEN and DATASET_REPOSITORY:
+                try:
+                    now = datetime.datetime.now()
+                    current_time = now.strftime("%Y-%m-%d %H:%M:%S")
+                    print(f"[{current_time}] Pushing prompt and completion to the Hub")
+                    save_inputs_and_outputs(
+                        now,
+                        prompt,
+                        "",
+                        {
+                            "temperature": temperature,
+                            "top_p": top_p,
+                            "top_k": top_k,
+                            "num_beams": num_beams,
+                            "max_tokens": max_tokens,
+                            "repetition_penalty": repetition_penalty,
+                        },
+                    )
+                except Exception as e:
+                    print(e)
+            return (
+                f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} tokens are used.",
+                gr.update(interactive=True),
+                gr.update(interactive=True),
+            )
+        input_ids = inputs["input_ids"].to(device)
+        generation_config = GenerationConfig(
+            do_sample=False,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            num_beams=num_beams,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token=tokenizer.eos_token_id,
+        )
+        with torch.no_grad():
+            generation_output = model.generate(
+                input_ids=input_ids,
+                generation_config=generation_config,
+                return_dict_in_generate=True,
+                output_scores=True,
+                max_new_tokens=max_tokens - len(input_ids),
+            )
+        s = generation_output.sequences[0]
+        output = tokenizer.decode(s, skip_special_tokens=True)
+        if prompt.endswith("Response:"):
+            output = output.split("### Response:")[1].strip()
+        elif prompt.endswith("返答:"):
+            output = output.split("### 返答:")[1].strip()
+        else:
+            raise ValueError(f"No valid prompt ends. {prompt}")
         if HF_TOKEN and DATASET_REPOSITORY:
             try:
                 now = datetime.datetime.now()
@@ -175,7 +230,7 @@ def evaluate(
                 save_inputs_and_outputs(
                     now,
                     prompt,
-                    "",
+                    output,
                     {
                         "temperature": temperature,
                         "top_p": top_p,
@@ -187,59 +242,27 @@ def evaluate(
                 )
             except Exception as e:
                 print(e)
+        return output, gr.update(interactive=True), gr.update(interactive=True)
+    except Exception as e:
+        print(e)
+        import traceback
+
+        if SLACK_WEBHOOK:
+            payload_dic = {
+                "text": f"BASE_MODEL: {BASE_MODEL}\n LORA_WEIGHTS: {LORA_WEIGHTS}\n"
+                + f"instruction: {instruction}\ninput: {input}\ntemperature: {temperature}\n"
+                + f"max_tokens: {max_tokens}\nrepetition_penalty: {repetition_penalty}\n\n"
+                + str(traceback.format_exc()),
+                "username": "Hugging Face Space",
+                "channel": "#monitor",
+            }
+
+            requests.post(SLACK_WEBHOOK, data=json.dumps(payload_dic))
         return (
-            f"please reduce the input length. Currently, {len(inputs['input_ids'][0])} tokens are used.",
+            "Error happend. Please return later.",
             gr.update(interactive=True),
             gr.update(interactive=True),
         )
-    input_ids = inputs["input_ids"].to(device)
-    generation_config = GenerationConfig(
-        do_sample=False,
-        temperature=temperature,
-        top_p=top_p,
-        top_k=top_k,
-        repetition_penalty=repetition_penalty,
-        num_beams=num_beams,
-        pad_token_id=tokenizer.pad_token_id,
-        eos_token=tokenizer.eos_token_id,
-    )
-    with torch.no_grad():
-        generation_output = model.generate(
-            input_ids=input_ids,
-            generation_config=generation_config,
-            return_dict_in_generate=True,
-            output_scores=True,
-            max_new_tokens=max_tokens - len(input_ids),
-        )
-    s = generation_output.sequences[0]
-    output = tokenizer.decode(s, skip_special_tokens=True)
-    if prompt.endswith("Response:"):
-        output = output.split("### Response:")[1].strip()
-    elif prompt.endswith("返答:"):
-        output = output.split("### 返答:")[1].strip()
-    else:
-        raise ValueError(f"No valid prompt ends. {prompt}")
-    if HF_TOKEN and DATASET_REPOSITORY:
-        try:
-            now = datetime.datetime.now()
-            current_time = now.strftime("%Y-%m-%d %H:%M:%S")
-            print(f"[{current_time}] Pushing prompt and completion to the Hub")
-            save_inputs_and_outputs(
-                now,
-                prompt,
-                output,
-                {
-                    "temperature": temperature,
-                    "top_p": top_p,
-                    "top_k": top_k,
-                    "num_beams": num_beams,
-                    "max_tokens": max_tokens,
-                    "repetition_penalty": repetition_penalty,
-                },
-            )
-        except Exception as e:
-            print(e)
-    return output, gr.update(interactive=True), gr.update(interactive=True)
 
 
 def reset_textbox():
@@ -324,8 +347,6 @@ with gr.Blocks(
 
 Please note that this space utilizes [decapoda-research/llama-13b-hf](https://huggingface.co/decapoda-research/llama-13b-hf) and its special license is applied.
 
-
-
 ## データ収集、利用、共有に関するユーザーの同意:
 本アプリを使用することにより、提供するデータに関する以下の条件に同意するものとします:
 
@@ -367,5 +388,5 @@ with gr.Blocks(
     clear_button.click(reset_textbox, [], [instruction, inputs, outputs], queue=False)
 
 demo.queue(max_size=20, concurrency_count=NUM_THREADS, api_open=False).launch(
-    share=True, server_name="0.0.0.0", server_port=7860
+    server_name="0.0.0.0", server_port=7860
 )
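
For reference, the new except branch added to evaluate() reports failures to a Slack incoming webhook. Below is a minimal, self-contained sketch of that reporting pattern, assuming only that the SLACK_WEBHOOK environment variable holds an incoming-webhook URL; the helper name report_error_to_slack and the demo message are illustrative and not part of app.py.

# Minimal sketch of the Slack incoming-webhook reporting used by the new
# except branch in evaluate(). Assumes SLACK_WEBHOOK holds a webhook URL;
# the helper name and message text below are illustrative only.
import json
import os
import traceback

import requests

SLACK_WEBHOOK = os.environ.get("SLACK_WEBHOOK", None)


def report_error_to_slack(context: str) -> None:
    # Post the current traceback plus caller-supplied context to Slack, if configured.
    if not SLACK_WEBHOOK:
        return
    payload = {
        "text": context + "\n\n" + traceback.format_exc(),
        "username": "Hugging Face Space",  # display name shown in Slack
        "channel": "#monitor",  # channel override; honored only by some webhook types
    }
    # Incoming webhooks accept a JSON body; app.py sends it as data=json.dumps(...)
    requests.post(SLACK_WEBHOOK, data=json.dumps(payload))


if __name__ == "__main__":
    try:
        raise RuntimeError("demo failure")
    except Exception:
        report_error_to_slack("BASE_MODEL: <model>\ninstruction: <instruction>")

As in the committed code, the post is fire-and-forget: the payload is serialized with json.dumps and sent with requests.post, and a failure of the webhook call itself would raise a new exception, so a hardened variant might wrap the post in its own try/except.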