nxphi47 commited on
Commit
9e8444e
1 Parent(s): 52d5bca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -78
app.py CHANGED
@@ -973,6 +973,53 @@ gr.ChatInterface._setup_stop_events = _setup_stop_events
973
  gr.ChatInterface._setup_events = _setup_events
974
 
975
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976
  def vllm_abort(self: Any):
977
  sh = self.llm_engine.scheduler
978
  for g in (sh.waiting + sh.running + sh.swapped):
@@ -1251,7 +1298,7 @@ def format_conversation(history):
1251
 
1252
  def maybe_upload_to_dataset():
1253
  global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1254
- if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH is not "":
1255
  with open(LOG_PATH, 'r', encoding='utf-8') as f:
1256
  convos = {}
1257
  for l in f:
@@ -1350,6 +1397,7 @@ def maybe_delete_folder():
1350
  except Exception as e:
1351
  print('Failed to delete %s. Reason: %s' % (file_path, e))
1352
 
 
1353
  AGREE_POP_SCRIPTS = """
1354
  async () => {
1355
  alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
@@ -1366,6 +1414,7 @@ def debug_file_function(
1366
  stop_strings: str = "[STOP],<s>,</s>",
1367
  current_time: Optional[float] = None,
1368
  ):
 
1369
  files = files if isinstance(files, list) else [files]
1370
  print(files)
1371
  filenames = [f.name for f in files]
@@ -1391,7 +1440,9 @@ def debug_file_function(
1391
 
1392
 
1393
  def validate_file_item(filename, index, item: Dict[str, str]):
1394
- # BATCH_INFER_MAX_PROMPT_TOKENS
 
 
1395
  message = item['prompt'].strip()
1396
 
1397
  if len(message) == 0:
@@ -1399,7 +1450,7 @@ def validate_file_item(filename, index, item: Dict[str, str]):
1399
 
1400
  message_safety = safety_check(message, history=None)
1401
  if message_safety is not None:
1402
- raise gr.Error(f'Prompt {index} unsafe or supported: {message_safety}')
1403
 
1404
  tokenizer = llm.get_tokenizer() if llm is not None else None
1405
  if tokenizer is None or len(tokenizer.encode(message, add_special_tokens=False)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
@@ -1423,25 +1474,33 @@ def read_validate_json_files(files: Union[str, List[str]]):
1423
  validate_file_item(fname, i, x)
1424
 
1425
  all_items.extend(items)
 
1426
  if len(all_items) > BATCH_INFER_MAX_ITEMS:
1427
  raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
1428
 
1429
- return all_items
1430
 
1431
 
1432
- def remove_gradio_cache():
 
1433
  import shutil
1434
  for root, dirs, files in os.walk('/tmp/gradio/'):
1435
  for f in files:
1436
- os.unlink(os.path.join(root, f))
1437
- for d in dirs:
1438
- shutil.rmtree(os.path.join(root, d))
 
 
 
 
 
 
1439
 
1440
 
1441
  def maybe_upload_batch_set(pred_json_path):
1442
  global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1443
 
1444
- if SAVE_LOGS and DATA_SET_REPO_PATH is not "":
1445
  try:
1446
  from huggingface_hub import upload_file
1447
  path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
@@ -1470,7 +1529,7 @@ def batch_inference(
1470
  system_prompt: Optional[str] = SYSTEM_PROMPT_1
1471
  ):
1472
  """
1473
- Must handle
1474
 
1475
  """
1476
  global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
@@ -1493,11 +1552,10 @@ def batch_inference(
1493
  frequency_penalty = float(frequency_penalty)
1494
  max_tokens = int(max_tokens)
1495
 
1496
- all_items = read_validate_json_files(files)
1497
 
1498
  # remove all items in /tmp/gradio/
1499
- remove_gradio_cache()
1500
-
1501
 
1502
  if prompt_mode == 'chat':
1503
  prompt_format_fn = llama_chat_multiturn_sys_input_seq_constructor
@@ -1552,60 +1610,14 @@ def batch_inference(
1552
 
1553
 
1554
  # BATCH_INFER_MAX_ITEMS
1555
- FILE_UPLOAD_DESC = f"""File upload json format, with JSON object as list of dict with < {BATCH_INFER_MAX_ITEMS} items"""
1556
- FILE_UPLOAD_DESCRIPTION = FILE_UPLOAD_DESC + """
1557
  ```
1558
- [ {\"id\": 0, \"prompt\": \"Hello world\"} , {\"id\": 1, \"prompt\": \"Hi there?\"}]
1559
  ```
1560
  """
1561
 
1562
 
1563
- # https://huggingface.co/spaces/yuntian-deng/ChatGPT4Turbo/blob/main/app.py
1564
- @document()
1565
- class CusTabbedInterface(gr.Blocks):
1566
- def __init__(
1567
- self,
1568
- interface_list: list[gr.Interface],
1569
- tab_names: Optional[list[str]] = None,
1570
- title: Optional[str] = None,
1571
- description: Optional[str] = None,
1572
- theme: Optional[gr.Theme] = None,
1573
- analytics_enabled: Optional[bool] = None,
1574
- css: Optional[str] = None,
1575
- ):
1576
- """
1577
- Parameters:
1578
- interface_list: a list of interfaces to be rendered in tabs.
1579
- tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
1580
- title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
1581
- analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
1582
- css: custom css or path to custom css file to apply to entire Blocks
1583
- Returns:
1584
- a Gradio Tabbed Interface for the given interfaces
1585
- """
1586
- super().__init__(
1587
- title=title or "Gradio",
1588
- theme=theme,
1589
- analytics_enabled=analytics_enabled,
1590
- mode="tabbed_interface",
1591
- css=css,
1592
- )
1593
- self.description = description
1594
- if tab_names is None:
1595
- tab_names = [f"Tab {i}" for i in range(len(interface_list))]
1596
- with self:
1597
- if title:
1598
- gr.Markdown(
1599
- f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
1600
- )
1601
- if description:
1602
- gr.Markdown(description)
1603
- with gr.Tabs():
1604
- for interface, tab_name in zip(interface_list, tab_names):
1605
- with gr.Tab(label=tab_name):
1606
- interface.render()
1607
-
1608
-
1609
  def launch():
1610
  global demo, llm, DEBUG, LOG_FILE
1611
  model_desc = MODEL_DESC
@@ -1703,33 +1715,33 @@ def launch():
1703
 
1704
  if ENABLE_BATCH_INFER:
1705
 
1706
- demo_file = gr.Interface(
1707
  batch_inference,
1708
  inputs=[
1709
  gr.File(file_count='single', file_types=['json']),
1710
  gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
1711
- gr.Number(value=temperature, label='Temperature (higher -> more random)'),
1712
- gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
1713
- gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
1714
- gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
1715
- gr.Textbox(value="[STOP],[END],<s>,</s>", label='Comma-separated STOP string to stop generation only in few-shot mode', lines=1),
1716
  gr.Number(value=0, label='current_time', visible=False),
1717
  ],
1718
  outputs=[
1719
  # "file",
1720
  gr.File(label="Generated file"),
1721
- # gr.Textbox(),
1722
  # "json"
1723
- gr.JSON(label='Example outputs (max 2 samples)')
1724
  ],
1725
- # examples=[[[os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
1726
- # os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
1727
- # os.path.join(os.path.dirname(__file__),"files/titanic.csv")]]],
1728
- # cache_examples=True
1729
- description=FILE_UPLOAD_DESCRIPTION
 
 
1730
  )
1731
 
1732
-
1733
  demo_chat = gr.ChatInterface(
1734
  response_fn,
1735
  chatbot=ChatBot(
@@ -1757,8 +1769,8 @@ def launch():
1757
  # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
1758
  ],
1759
  )
1760
- demo = CusTabbedInterface(
1761
- interface_list=[demo_chat, demo_file],
1762
  tab_names=["Chat Interface", "Batch Inference"],
1763
  title=f"{model_title}",
1764
  description=f"{model_desc}",
@@ -1824,3 +1836,4 @@ def main():
1824
 
1825
  if __name__ == "__main__":
1826
  main()
 
 
973
  gr.ChatInterface._setup_events = _setup_events
974
 
975
 
976
+
977
+ @document()
978
+ class CustomTabbedInterface(gr.Blocks):
979
+ def __init__(
980
+ self,
981
+ interface_list: list[gr.Interface],
982
+ tab_names: Optional[list[str]] = None,
983
+ title: Optional[str] = None,
984
+ description: Optional[str] = None,
985
+ theme: Optional[gr.Theme] = None,
986
+ analytics_enabled: Optional[bool] = None,
987
+ css: Optional[str] = None,
988
+ ):
989
+ """
990
+ Parameters:
991
+ interface_list: a list of interfaces to be rendered in tabs.
992
+ tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
993
+ title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
994
+ analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
995
+ css: custom css or path to custom css file to apply to entire Blocks
996
+ Returns:
997
+ a Gradio Tabbed Interface for the given interfaces
998
+ """
999
+ super().__init__(
1000
+ title=title or "Gradio",
1001
+ theme=theme,
1002
+ analytics_enabled=analytics_enabled,
1003
+ mode="tabbed_interface",
1004
+ css=css,
1005
+ )
1006
+ self.description = description
1007
+ if tab_names is None:
1008
+ tab_names = [f"Tab {i}" for i in range(len(interface_list))]
1009
+ with self:
1010
+ if title:
1011
+ gr.Markdown(
1012
+ f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
1013
+ )
1014
+ if description:
1015
+ gr.Markdown(description)
1016
+ with gr.Tabs():
1017
+ for interface, tab_name in zip(interface_list, tab_names):
1018
+ with gr.Tab(label=tab_name):
1019
+ interface.render()
1020
+
1021
+
1022
+
1023
  def vllm_abort(self: Any):
1024
  sh = self.llm_engine.scheduler
1025
  for g in (sh.waiting + sh.running + sh.swapped):
 
1298
 
1299
  def maybe_upload_to_dataset():
1300
  global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1301
+ if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH != "":
1302
  with open(LOG_PATH, 'r', encoding='utf-8') as f:
1303
  convos = {}
1304
  for l in f:
 
1397
  except Exception as e:
1398
  print('Failed to delete %s. Reason: %s' % (file_path, e))
1399
 
1400
+
1401
  AGREE_POP_SCRIPTS = """
1402
  async () => {
1403
  alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
 
1414
  stop_strings: str = "[STOP],<s>,</s>",
1415
  current_time: Optional[float] = None,
1416
  ):
1417
+ """This is only for debug purpose"""
1418
  files = files if isinstance(files, list) else [files]
1419
  print(files)
1420
  filenames = [f.name for f in files]
 
1440
 
1441
 
1442
  def validate_file_item(filename, index, item: Dict[str, str]):
1443
+ """
1444
+ check safety for items in files
1445
+ """
1446
  message = item['prompt'].strip()
1447
 
1448
  if len(message) == 0:
 
1450
 
1451
  message_safety = safety_check(message, history=None)
1452
  if message_safety is not None:
1453
+ raise gr.Error(f'Prompt {index} invalid: {message_safety}')
1454
 
1455
  tokenizer = llm.get_tokenizer() if llm is not None else None
1456
  if tokenizer is None or len(tokenizer.encode(message, add_special_tokens=False)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
 
1474
  validate_file_item(fname, i, x)
1475
 
1476
  all_items.extend(items)
1477
+
1478
  if len(all_items) > BATCH_INFER_MAX_ITEMS:
1479
  raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
1480
 
1481
+ return all_items, filenames
1482
 
1483
 
1484
+ def remove_gradio_cache(exclude_names=None):
1485
+ """remove gradio cache to avoid flooding"""
1486
  import shutil
1487
  for root, dirs, files in os.walk('/tmp/gradio/'):
1488
  for f in files:
1489
+ # if not any(f in ef for ef in except_files):
1490
+ if exclude_names is None or not any(ef in f for ef in exclude_names):
1491
+ print(f'Remove: {f}')
1492
+ os.unlink(os.path.join(root, f))
1493
+ # for d in dirs:
1494
+ # # if not any(d in ef for ef in except_files):
1495
+ # if exclude_names is None or not any(ef in d for ef in exclude_names):
1496
+ # print(f'Remove d: {d}')
1497
+ # shutil.rmtree(os.path.join(root, d))
1498
 
1499
 
1500
  def maybe_upload_batch_set(pred_json_path):
1501
  global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
1502
 
1503
+ if SAVE_LOGS and DATA_SET_REPO_PATH != "":
1504
  try:
1505
  from huggingface_hub import upload_file
1506
  path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
 
1529
  system_prompt: Optional[str] = SYSTEM_PROMPT_1
1530
  ):
1531
  """
1532
+ Handle file upload batch inference
1533
 
1534
  """
1535
  global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
 
1552
  frequency_penalty = float(frequency_penalty)
1553
  max_tokens = int(max_tokens)
1554
 
1555
+ all_items, filenames = read_validate_json_files(files)
1556
 
1557
  # remove all items in /tmp/gradio/
1558
+ remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
 
1559
 
1560
  if prompt_mode == 'chat':
1561
  prompt_format_fn = llama_chat_multiturn_sys_input_seq_constructor
 
1610
 
1611
 
1612
  # BATCH_INFER_MAX_ITEMS
1613
+ FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
1614
+ each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
1615
  ```
1616
+ [ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
1617
  ```
1618
  """
1619
 
1620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1621
  def launch():
1622
  global demo, llm, DEBUG, LOG_FILE
1623
  model_desc = MODEL_DESC
 
1715
 
1716
  if ENABLE_BATCH_INFER:
1717
 
1718
+ demo_file_upload = gr.Interface(
1719
  batch_inference,
1720
  inputs=[
1721
  gr.File(file_count='single', file_types=['json']),
1722
  gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
1723
+ gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
1724
+ gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
1725
+ gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
1726
+ gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
1727
+ gr.Textbox(value="[STOP],[END],<s>,</s>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
1728
  gr.Number(value=0, label='current_time', visible=False),
1729
  ],
1730
  outputs=[
1731
  # "file",
1732
  gr.File(label="Generated file"),
 
1733
  # "json"
1734
+ gr.JSON(label='Example outputs (display 2 samples)')
1735
  ],
1736
+ description=FILE_UPLOAD_DESCRIPTION,
1737
+ allow_flagging=False,
1738
+ examples=[
1739
+ ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "[STOP],[END],<s>,</s>"],
1740
+ ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "[STOP],[END],<s>,</s>,\\n"]
1741
+ ],
1742
+ # cache_examples=True,
1743
  )
1744
 
 
1745
  demo_chat = gr.ChatInterface(
1746
  response_fn,
1747
  chatbot=ChatBot(
 
1769
  # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
1770
  ],
1771
  )
1772
+ demo = CustomTabbedInterface(
1773
+ interface_list=[demo_chat, demo_file_upload],
1774
  tab_names=["Chat Interface", "Batch Inference"],
1775
  title=f"{model_title}",
1776
  description=f"{model_desc}",
 
1836
 
1837
  if __name__ == "__main__":
1838
  main()
1839
+