binary-husky commited on
Commit
403dd2f
·
1 Parent(s): 3f635bc

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +34 -23
main.py CHANGED
@@ -1,13 +1,15 @@
1
  import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
2
- import gradio as gr
3
  from predict import predict
4
- from toolbox import format_io, find_free_port
5
 
6
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
7
- from config_private import proxies, WEB_PORT, LLM_MODEL
 
8
 
9
  # 如果WEB_PORT是-1, 则随机选取WEB端口
10
  PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
 
11
 
12
  initial_prompt = "Serve me as a writing and programming assistant."
13
  title_html = """<h1 align="center">ChatGPT 学术优化</h1>"""
@@ -15,7 +17,7 @@ title_html = """<h1 align="center">ChatGPT 学术优化</h1>"""
15
  # 问询记录, python 版本建议3.9+(越新越好)
16
  import logging
17
  os.makedirs('gpt_log', exist_ok=True)
18
- try:logging.basicConfig(filename='gpt_log/chat_secrets.log', level=logging.INFO, encoding='utf-8')
19
  except:logging.basicConfig(filename='gpt_log/chat_secrets.log', level=logging.INFO)
20
  print('所有问询记录将自动保存在本地目录./gpt_log/chat_secrets.log, 请注意自我隐私保护哦!')
21
 
@@ -24,7 +26,7 @@ from functional import get_functionals
24
  functional = get_functionals()
25
 
26
  # 对一些丧心病狂的实验性功能模块进行测试
27
- from functional_crazy import get_crazy_functionals, on_file_uploaded, on_report_generated
28
  crazy_functional = get_crazy_functionals()
29
 
30
  # 处理markdown文本格式的转变
@@ -34,6 +36,7 @@ gr.Chatbot.postprocess = format_io
34
  from theme import adjust_theme
35
  set_theme = adjust_theme()
36
 
 
37
  with gr.Blocks(theme=set_theme, analytics_enabled=False) as demo:
38
  gr.HTML(title_html)
39
  with gr.Row():
@@ -42,14 +45,15 @@ with gr.Blocks(theme=set_theme, analytics_enabled=False) as demo:
42
  chatbot.style(height=1000)
43
  chatbot.style()
44
  history = gr.State([])
45
- TRUE = gr.State(True)
46
- FALSE = gr.State(False)
47
  with gr.Column(scale=1):
48
  with gr.Row():
49
  with gr.Column(scale=12):
50
  txt = gr.Textbox(show_label=False, placeholder="Input question here.").style(container=False)
51
  with gr.Column(scale=1):
52
- submitBtn = gr.Button("提交", variant="primary")
 
 
 
53
  with gr.Row():
54
  from check_proxy import check_proxy
55
  statusDisplay = gr.Markdown(f"Tip: 按Enter提交, 按Shift+Enter换行. \nNetwork: {check_proxy(proxies)}\nModel: {LLM_MODEL}")
@@ -67,36 +71,43 @@ with gr.Blocks(theme=set_theme, analytics_enabled=False) as demo:
67
  gr.Markdown("上传本地文件供上面的实验性功能调用.")
68
  with gr.Row():
69
  file_upload = gr.Files(label='任何文件,但推荐上传压缩文件(zip, tar)', file_count="multiple")
70
-
71
- systemPromptTxt = gr.Textbox(show_label=True, placeholder=f"System Prompt", label="System prompt", value=initial_prompt).style(container=True)
72
- #inputs, top_p, temperature, top_k, repetition_penalty
73
  with gr.Accordion("arguments", open=False):
74
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.01,interactive=True, label="Top-p (nucleus sampling)",)
75
- temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
 
 
 
76
 
77
- txt.submit(predict, [txt, top_p, temperature, chatbot, history, systemPromptTxt], [chatbot, history, statusDisplay])
78
- submitBtn.click(predict, [txt, top_p, temperature, chatbot, history, systemPromptTxt], [chatbot, history, statusDisplay], show_progress=True)
 
 
 
79
  for k in functional:
80
- functional[k]["Button"].click(predict,
81
- [txt, top_p, temperature, chatbot, history, systemPromptTxt, TRUE, gr.State(k)], [chatbot, history, statusDisplay], show_progress=True)
 
82
  file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt], [chatbot, txt])
83
  for k in crazy_functional:
84
- click_handle = crazy_functional[k]["Button"].click(crazy_functional[k]["Function"],
85
- [txt, top_p, temperature, chatbot, history, systemPromptTxt, gr.State(PORT)], [chatbot, history, statusDisplay]
86
  )
87
  try: click_handle.then(on_report_generated, [file_upload, chatbot], [file_upload, chatbot])
88
  except: pass
 
 
89
 
90
-
91
- # 延迟函数, 做一些准备工作, 最后尝试打开浏览器
92
  def auto_opentab_delay():
93
  import threading, webbrowser, time
94
  print(f"URL http://localhost:{PORT}")
95
- def open(): time.sleep(2)
96
- webbrowser.open_new_tab(f'http://localhost:{PORT}')
 
97
  t = threading.Thread(target=open)
98
  t.daemon = True; t.start()
99
 
100
  auto_opentab_delay()
101
  demo.title = "ChatGPT 学术优化"
102
- demo.queue().launch(server_name="0.0.0.0", share=True, server_port=PORT)
 
1
  import os; os.environ['no_proxy'] = '*' # 避免代理网络产生意外污染
2
+ import gradio as gr
3
  from predict import predict
4
+ from toolbox import format_io, find_free_port, on_file_uploaded, on_report_generated
5
 
6
  # 建议您复制一个config_private.py放自己的秘密, 如API和代理网址, 避免不小心传github被别人看到
7
+ try: from config_private import proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION
8
+ except: from config import proxies, WEB_PORT, LLM_MODEL, CONCURRENT_COUNT, AUTHENTICATION
9
 
10
  # 如果WEB_PORT是-1, 则随机选取WEB端口
11
  PORT = find_free_port() if WEB_PORT <= 0 else WEB_PORT
12
+ AUTHENTICATION = None if AUTHENTICATION == [] else AUTHENTICATION
13
 
14
  initial_prompt = "Serve me as a writing and programming assistant."
15
  title_html = """<h1 align="center">ChatGPT 学术优化</h1>"""
 
17
  # 问询记录, python 版本建议3.9+(越新越好)
18
  import logging
19
  os.makedirs('gpt_log', exist_ok=True)
20
+ try:logging.basicConfig(filename='gpt_log/chat_secrets.log', level=logging.INFO, encoding='utf-8')
21
  except:logging.basicConfig(filename='gpt_log/chat_secrets.log', level=logging.INFO)
22
  print('所有问询记录将自动保存在本地目录./gpt_log/chat_secrets.log, 请注意自我隐私保护哦!')
23
 
 
26
  functional = get_functionals()
27
 
28
  # 对一些丧心病狂的实验性功能模块进行测试
29
+ from functional_crazy import get_crazy_functionals
30
  crazy_functional = get_crazy_functionals()
31
 
32
  # 处理markdown文本格式的转变
 
36
  from theme import adjust_theme
37
  set_theme = adjust_theme()
38
 
39
+ cancel_handles = []
40
  with gr.Blocks(theme=set_theme, analytics_enabled=False) as demo:
41
  gr.HTML(title_html)
42
  with gr.Row():
 
45
  chatbot.style(height=1000)
46
  chatbot.style()
47
  history = gr.State([])
 
 
48
  with gr.Column(scale=1):
49
  with gr.Row():
50
  with gr.Column(scale=12):
51
  txt = gr.Textbox(show_label=False, placeholder="Input question here.").style(container=False)
52
  with gr.Column(scale=1):
53
+ with gr.Row():
54
+ resetBtn = gr.Button("重置", variant="secondary")
55
+ submitBtn = gr.Button("提交", variant="primary")
56
+ stopBtn = gr.Button("停止", variant="stop")
57
  with gr.Row():
58
  from check_proxy import check_proxy
59
  statusDisplay = gr.Markdown(f"Tip: 按Enter提交, 按Shift+Enter换行. \nNetwork: {check_proxy(proxies)}\nModel: {LLM_MODEL}")
 
71
  gr.Markdown("上传本地文件供上面的实验性功能调用.")
72
  with gr.Row():
73
  file_upload = gr.Files(label='任何文件,但推荐上传压缩文件(zip, tar)', file_count="multiple")
74
+ system_prompt = gr.Textbox(show_label=True, placeholder=f"System Prompt", label="System prompt", value=initial_prompt).style(container=True)
 
 
75
  with gr.Accordion("arguments", open=False):
76
  top_p = gr.Slider(minimum=-0, maximum=1.0, value=1.0, step=0.01,interactive=True, label="Top-p (nucleus sampling)",)
77
+ temperature = gr.Slider(minimum=-0, maximum=2.0, value=1.0, step=0.01, interactive=True, label="Temperature",)
78
+
79
+ predict_args = dict(fn=predict, inputs=[txt, top_p, temperature, chatbot, history, system_prompt], outputs=[chatbot, history, statusDisplay], show_progress=True)
80
+ empty_txt_args = dict(fn=lambda: "", inputs=[], outputs=[txt]) # 用于在提交后清空输入栏
81
 
82
+ cancel_handles.append(txt.submit(**predict_args))
83
+ # txt.submit(**empty_txt_args) 在提交后清空输入栏
84
+ cancel_handles.append(submitBtn.click(**predict_args))
85
+ # submitBtn.click(**empty_txt_args) 在提交后清空输入栏
86
+ resetBtn.click(lambda: ([], [], "已重置"), None, [chatbot, history, statusDisplay])
87
  for k in functional:
88
+ click_handle = functional[k]["Button"].click(predict,
89
+ [txt, top_p, temperature, chatbot, history, system_prompt, gr.State(True), gr.State(k)], [chatbot, history, statusDisplay], show_progress=True)
90
+ cancel_handles.append(click_handle)
91
  file_upload.upload(on_file_uploaded, [file_upload, chatbot, txt], [chatbot, txt])
92
  for k in crazy_functional:
93
+ click_handle = crazy_functional[k]["Button"].click(crazy_functional[k]["Function"],
94
+ [txt, top_p, temperature, chatbot, history, system_prompt, gr.State(PORT)], [chatbot, history, statusDisplay]
95
  )
96
  try: click_handle.then(on_report_generated, [file_upload, chatbot], [file_upload, chatbot])
97
  except: pass
98
+ cancel_handles.append(click_handle)
99
+ stopBtn.click(fn=None, inputs=None, outputs=None, cancels=cancel_handles)
100
 
101
+ # gradio的inbrowser触发不太稳定,回滚代码到原始的浏览器打开函数
 
102
  def auto_opentab_delay():
103
  import threading, webbrowser, time
104
  print(f"URL http://localhost:{PORT}")
105
+ def open():
106
+ time.sleep(2)
107
+ webbrowser.open_new_tab(f'http://localhost:{PORT}')
108
  t = threading.Thread(target=open)
109
  t.daemon = True; t.start()
110
 
111
  auto_opentab_delay()
112
  demo.title = "ChatGPT 学术优化"
113
+ demo.queue(concurrency_count=CONCURRENT_COUNT).launch(server_name="0.0.0.0", share=True, server_port=PORT, auth=AUTHENTICATION)