nxphi47 committed on
Commit 52d5bca
1 Parent(s): 17d2ea7

Update app.py

Files changed (1)
  1. app.py +80 -91
app.py CHANGED
@@ -6,6 +6,7 @@
 VLLM-based demo script to launch Language chat model for Southeast Asian Languages
 """
 
+
 import os
 import numpy as np
 import argparse
@@ -972,53 +973,6 @@ gr.ChatInterface._setup_stop_events = _setup_stop_events
 gr.ChatInterface._setup_events = _setup_events
 
 
-
-@document()
-class CustomTabbedInterface(gr.Blocks):
-    def __init__(
-        self,
-        interface_list: list[gr.Interface],
-        tab_names: Optional[list[str]] = None,
-        title: Optional[str] = None,
-        description: Optional[str] = None,
-        theme: Optional[gr.Theme] = None,
-        analytics_enabled: Optional[bool] = None,
-        css: Optional[str] = None,
-    ):
-        """
-        Parameters:
-            interface_list: a list of interfaces to be rendered in tabs.
-            tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
-            title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
-            analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
-            css: custom css or path to custom css file to apply to entire Blocks
-        Returns:
-            a Gradio Tabbed Interface for the given interfaces
-        """
-        super().__init__(
-            title=title or "Gradio",
-            theme=theme,
-            analytics_enabled=analytics_enabled,
-            mode="tabbed_interface",
-            css=css,
-        )
-        self.description = description
-        if tab_names is None:
-            tab_names = [f"Tab {i}" for i in range(len(interface_list))]
-        with self:
-            if title:
-                gr.Markdown(
-                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
-                )
-            if description:
-                gr.Markdown(description)
-            with gr.Tabs():
-                for interface, tab_name in zip(interface_list, tab_names):
-                    with gr.Tab(label=tab_name):
-                        interface.render()
-
-
-
 def vllm_abort(self: Any):
     sh = self.llm_engine.scheduler
     for g in (sh.waiting + sh.running + sh.swapped):
@@ -1297,7 +1251,7 @@ def format_conversation(history):
 
 def maybe_upload_to_dataset():
     global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
-    if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH != "":
+    if SAVE_LOGS and os.path.exists(LOG_PATH) and DATA_SET_REPO_PATH is not "":
         with open(LOG_PATH, 'r', encoding='utf-8') as f:
             convos = {}
             for l in f:
@@ -1396,7 +1350,6 @@ def maybe_delete_folder():
         except Exception as e:
             print('Failed to delete %s. Reason: %s' % (file_path, e))
 
-
 AGREE_POP_SCRIPTS = """
 async () => {
     alert("To use our service, you are required to agree to the following terms:\\nYou must not use our service to generate any harmful, unethical or illegal content that violates local and international laws, including but not limited to hate speech, violence and deception.\\nThe service may collect user dialogue data for performance improvement, and reserves the right to distribute it under CC-BY or similar license. So do not enter any personal information!");
@@ -1413,7 +1366,6 @@ def debug_file_function(
        stop_strings: str = "[STOP],<s>,</s>",
        current_time: Optional[float] = None,
 ):
-    """This is only for debug purpose"""
     files = files if isinstance(files, list) else [files]
     print(files)
     filenames = [f.name for f in files]
@@ -1439,9 +1391,7 @@ def debug_file_function(
 
 
 def validate_file_item(filename, index, item: Dict[str, str]):
-    """
-    check safety for items in files
-    """
+    # BATCH_INFER_MAX_PROMPT_TOKENS
     message = item['prompt'].strip()
 
     if len(message) == 0:
@@ -1449,7 +1399,7 @@ def validate_file_item(filename, index, item: Dict[str, str]):
 
     message_safety = safety_check(message, history=None)
     if message_safety is not None:
-        raise gr.Error(f'Prompt {index} invalid: {message_safety}')
+        raise gr.Error(f'Prompt {index} unsafe or supported: {message_safety}')
 
     tokenizer = llm.get_tokenizer() if llm is not None else None
     if tokenizer is None or len(tokenizer.encode(message, add_special_tokens=False)) >= BATCH_INFER_MAX_PROMPT_TOKENS:
@@ -1473,33 +1423,25 @@ def read_validate_json_files(files: Union[str, List[str]]):
             validate_file_item(fname, i, x)
 
         all_items.extend(items)
-
     if len(all_items) > BATCH_INFER_MAX_ITEMS:
         raise gr.Error(f"Num samples {len(all_items)} > {BATCH_INFER_MAX_ITEMS} allowed.")
 
-    return all_items, filenames
+    return all_items
 
 
-def remove_gradio_cache(exclude_names=None):
-    """remove gradio cache to avoid flooding"""
+def remove_gradio_cache():
     import shutil
     for root, dirs, files in os.walk('/tmp/gradio/'):
         for f in files:
-            # if not any(f in ef for ef in except_files):
-            if exclude_names is None or not any(ef in f for ef in exclude_names):
-                print(f'Remove: {f}')
-                os.unlink(os.path.join(root, f))
-        # for d in dirs:
-        #     # if not any(d in ef for ef in except_files):
-        #     if exclude_names is None or not any(ef in d for ef in exclude_names):
-        #         print(f'Remove d: {d}')
-        #         shutil.rmtree(os.path.join(root, d))
+            os.unlink(os.path.join(root, f))
+        for d in dirs:
+            shutil.rmtree(os.path.join(root, d))
 
 
 def maybe_upload_batch_set(pred_json_path):
     global LOG_FILE, DATA_SET_REPO_PATH, SAVE_LOGS
 
-    if SAVE_LOGS and DATA_SET_REPO_PATH != "":
+    if SAVE_LOGS and DATA_SET_REPO_PATH is not "":
         try:
             from huggingface_hub import upload_file
             path_in_repo = "misc/" + os.path.basename(pred_json_path).replace(".json", f'.{time.time()}.json')
@@ -1528,7 +1470,7 @@ def batch_inference(
        system_prompt: Optional[str] = SYSTEM_PROMPT_1
 ):
     """
-    Handle file upload batch inference
+    Must handle
 
     """
     global LOG_FILE, LOG_PATH, DEBUG, llm, RES_PRINTED
@@ -1551,10 +1493,11 @@ def batch_inference(
     frequency_penalty = float(frequency_penalty)
     max_tokens = int(max_tokens)
 
-    all_items, filenames = read_validate_json_files(files)
+    all_items = read_validate_json_files(files)
 
     # remove all items in /tmp/gradio/
-    remove_gradio_cache(exclude_names=['upload_chat.json', 'upload_few_shot.json'])
+    remove_gradio_cache()
+
 
     if prompt_mode == 'chat':
         prompt_format_fn = llama_chat_multiturn_sys_input_seq_constructor
1594
  for res, item in zip(responses, all_items):
1595
  item['response'] = res
1596
 
 
1597
  save_path = BATCH_INFER_SAVE_TMP_FILE
1598
  os.makedirs(os.path.dirname(save_path), exist_ok=True)
1599
  with open(save_path, 'w', encoding='utf-8') as f:
@@ -1608,14 +1552,60 @@ def batch_inference(
 
 
 # BATCH_INFER_MAX_ITEMS
-FILE_UPLOAD_DESCRIPTION = f"""Upload JSON file as list of dict with < {BATCH_INFER_MAX_ITEMS} items, \
-each item has `prompt` key. We put guardrails to enhance safety, so do not input any harmful content or personal information! Re-upload the file after every submit. See the examples below.
+FILE_UPLOAD_DESC = f"""File upload json format, with JSON object as list of dict with < {BATCH_INFER_MAX_ITEMS} items"""
+FILE_UPLOAD_DESCRIPTION = FILE_UPLOAD_DESC + """
 ```
-[ {{"id": 0, "prompt": "Hello world"}} , {{"id": 1, "prompt": "Hi there?"}}]
+[ {\"id\": 0, \"prompt\": \"Hello world\"} , {\"id\": 1, \"prompt\": \"Hi there?\"}]
 ```
 """
 
 
+# https://huggingface.co/spaces/yuntian-deng/ChatGPT4Turbo/blob/main/app.py
+@document()
+class CusTabbedInterface(gr.Blocks):
+    def __init__(
+        self,
+        interface_list: list[gr.Interface],
+        tab_names: Optional[list[str]] = None,
+        title: Optional[str] = None,
+        description: Optional[str] = None,
+        theme: Optional[gr.Theme] = None,
+        analytics_enabled: Optional[bool] = None,
+        css: Optional[str] = None,
+    ):
+        """
+        Parameters:
+            interface_list: a list of interfaces to be rendered in tabs.
+            tab_names: a list of tab names. If None, the tab names will be "Tab 1", "Tab 2", etc.
+            title: a title for the interface; if provided, appears above the input and output components in large font. Also used as the tab title when opened in a browser window.
+            analytics_enabled: whether to allow basic telemetry. If None, will use GRADIO_ANALYTICS_ENABLED environment variable or default to True.
+            css: custom css or path to custom css file to apply to entire Blocks
+        Returns:
+            a Gradio Tabbed Interface for the given interfaces
+        """
+        super().__init__(
+            title=title or "Gradio",
+            theme=theme,
+            analytics_enabled=analytics_enabled,
+            mode="tabbed_interface",
+            css=css,
+        )
+        self.description = description
+        if tab_names is None:
+            tab_names = [f"Tab {i}" for i in range(len(interface_list))]
+        with self:
+            if title:
+                gr.Markdown(
+                    f"<h1 style='text-align: center; margin-bottom: 1rem'>{title}</h1>"
+                )
+            if description:
+                gr.Markdown(description)
+            with gr.Tabs():
+                for interface, tab_name in zip(interface_list, tab_names):
+                    with gr.Tab(label=tab_name):
+                        interface.render()
+
+
 def launch():
     global demo, llm, DEBUG, LOG_FILE
     model_desc = MODEL_DESC
@@ -1713,33 +1703,33 @@ def launch():
 
     if ENABLE_BATCH_INFER:
 
-        demo_file_upload = gr.Interface(
+        demo_file = gr.Interface(
             batch_inference,
             inputs=[
                 gr.File(file_count='single', file_types=['json']),
                 gr.Radio(["chat", "few-shot"], value='chat', label="Chat or Few-shot mode", info="Chat's output more user-friendly, Few-shot's output more consistent with few-shot patterns."),
-                gr.Number(value=temperature, label='Temperature', info="Higher -> more random"),
-                gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation'),
-                gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens'),
-                gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens'),
-                gr.Textbox(value="[STOP],[END],<s>,</s>", label='Stop strings', info='Comma-separated string to stop generation only in FEW-SHOT mode', lines=1),
+                gr.Number(value=temperature, label='Temperature (higher -> more random)'),
+                gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
+                gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
+                gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
+                gr.Textbox(value="[STOP],[END],<s>,</s>", label='Comma-separated STOP string to stop generation only in few-shot mode', lines=1),
                 gr.Number(value=0, label='current_time', visible=False),
             ],
             outputs=[
                 # "file",
                 gr.File(label="Generated file"),
+                # gr.Textbox(),
                 # "json"
-                gr.JSON(label='Example outputs (display 2 samples)')
+                gr.JSON(label='Example outputs (max 2 samples)')
             ],
-            description=FILE_UPLOAD_DESCRIPTION,
-            allow_flagging=False,
-            examples=[
-                ["upload_chat.json", "chat", 0.2, 1024, 0.5, 0, "[STOP],[END],<s>,</s>"],
-                ["upload_few_shot.json", "few-shot", 0.2, 128, 0.5, 0, "[STOP],[END],<s>,</s>,\\n"]
-            ],
-            # cache_examples=True,
+            # examples=[[[os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
+            #     os.path.join(os.path.dirname(__file__),"files/titanic.csv"),
+            #     os.path.join(os.path.dirname(__file__),"files/titanic.csv")]]],
+            # cache_examples=True
+            description=FILE_UPLOAD_DESCRIPTION
         )
 
+
    demo_chat = gr.ChatInterface(
        response_fn,
        chatbot=ChatBot(
@@ -1767,8 +1757,8 @@ def launch():
            # gr.Textbox(value=sys_prompt, label='System prompt', lines=8)
        ],
    )
-    demo = CustomTabbedInterface(
-        interface_list=[demo_chat, demo_file_upload],
+    demo = CusTabbedInterface(
+        interface_list=[demo_chat, demo_file],
        tab_names=["Chat Interface", "Batch Inference"],
        title=f"{model_title}",
        description=f"{model_desc}",
@@ -1834,4 +1824,3 @@ def main():
 
 if __name__ == "__main__":
     main()
-