File size: 8,730 Bytes
aa77a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bd2504
aa77a85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# import gradio as gr
import gradio
# import lmdb
# import base64
# import io
# import random
# import time
import json
import copy
# import sqlite3
from urllib.parse import urljoin
import openai

from app_js import api_key__get_from_browser, api_key__save_to_browser, saved_prompts_refresh_btn__click_js, selected_saved_prompt_title__change_js, saved_prompts_delete_btn__click_js, saved_prompts_save_btn__click_js, copy_prompt__click_js, paste_prompt__click_js, chat_copy_history_btn__click_js, chat_copy_history_md_btn__click_js, api_key_refresh_btn__click_js, api_key_save_btn__click_js

from functions import sequential_chat_fn, make_history_file_fn, on_click_send_btn, clear_history, copy_history, update_saved_prompt_titles, save_prompt, load_saved_prompt

introduction = """<center><h2>ChatGPT 数据蒸馏助手</h2></center>
"""


css = """
.table-wrap .cell-wrap input {min-width:80%}
#api-key-textbox textarea {filter:blur(8px); transition: filter 0.25s}
#api-key-textbox textarea:focus {filter:none}
#chat-log-md hr {margin-top: 1rem; margin-bottom: 1rem;}
#chat-markdown-wrap-box {max-height:80vh; overflow: auto !important;}
"""
with gradio.Blocks(title="ChatGPT 批处理", css=css) as demo:

    with gradio.Accordion("说明", open=True):
        gradio.Markdown(introduction)

    with gradio.Accordion("基本设置", open=False):
        system_prompt_enabled = gradio.Checkbox(label='是否使用系统全局提示语', info='是否要以“系统”身份,给 ChatGPT 描述任务?', value=True)
        # 系统提示
        system_prompt = gradio.Textbox(label='系统级全局提示语', info='以“系统”身份,给 ChatGPT 描述任务', value='你是一个医院导诊员,患者给你医院科室名称,请根据医院科室名称介绍一下该科室的作用,可以治疗的疾病,跟其它哪些科室联系比较紧密从而协助患者疾病的救治。请注意:你应该返回科室的作用、可以治疗的疾病、跟其它哪些科室联系比较紧密从而协助患者疾病的救治,而不要返回多余的内容,否则用户所使用的程序将会出错,给用户带来严重的损失。')
        # 用户消息模板
        user_message_template = gradio.Textbox(label='用户消息模板', info='要批量发送的消息的模板', value='科室名称:```___```')
        with gradio.Row():
            # 用户消息模板中的替换区
            user_message_template_mask = gradio.Textbox(label='模板占位符', info='消息模板中需要被替换的部分,可以是正则表达式', value='___')
            # 用户消息模板中的替换区是正则吗
            user_message_template_mask_is_regex = gradio.Checkbox(label='模板占位符是正则吗', info='模板占位符是不是正则表达式?', value=False)
        # 用户消息替换区清单文本
        user_message_list_text = gradio.Textbox(label='用户消息列表', info='所有待发送的消息', value='全科医学 心内科 感染病科 血液科 内分泌科 呼吸科 肾脏科 消化内科 风湿免疫科 肿瘤科 神经内科')
        with gradio.Row():
            # 用户消息替换区清单分隔符
            user_message_list_text_splitter = gradio.Textbox(label='用户消息分隔符', info='用于分割用户消息列表的分隔符,如逗号(`,`)、换行符(`\\n`)等,也可以是正则表达式,此处默认空格', value='\\s+')
            # 用户消息替换区清单分隔符是正则吗
            user_message_list_text_splitter_is_regex = gradio.Checkbox(label='分隔符是正则吗', info='用户消息分隔符是不是正则表达式?', value=True)
        # 历史记录条数
        history_prompt_num = gradio.Slider(label="发送历史记录条数", info='每次发生消息时,同时携带多少条先前的历史记录(以便 ChatGPT 了解上下文)', value=0, minimum=0, maximum=12000)

        # load_config_from_browser = gradio.Button("🔄 从浏览器加载配置")
        # save_config_to_browser = gradio.Button("💾 将配置保存到浏览器")
        # export_config_to_file = gradio.Button("📤 将配置导出为文件")

    # 更多参数
    with gradio.Accordion("更多参数", open=False):
        # 时间间隔
        sleep_base = gradio.Number(label='时间间隔 ms', value=700)
        # 时间间隔浮动
        sleep_rand = gradio.Number(label='时间间隔浮动 ms', value=200)
        # 那些参数
        prop_stream = gradio.Checkbox(label="流式传输 stream", value=True)
        prop_model = gradio.Textbox(label="模型 model", value="gpt-3.5-turbo")
        prop_temperature = gradio.Slider(label="temperature", value=0.7, minimum=0, maximum=2)
        prop_top_p = gradio.Slider(label="top_p", value=1, minimum=0, maximum=1)
        prop_choices_num = gradio.Slider(label="choices num(n)", value=1, minimum=1, maximum=20)
        prop_max_tokens = gradio.Slider(label="max_tokens", value=-1, minimum=-1, maximum=4096)
        prop_presence_penalty = gradio.Slider(label="presence_penalty", value=0, minimum=-2, maximum=2)
        prop_frequency_penalty = gradio.Slider(label="frequency_penalty", value=0, minimum=-2, maximum=2)
        prop_logit_bias = gradio.Textbox(label="logit_bias", visible=False)
    pass

    # API-Key
    token_text = gradio.Textbox(visible=False)
    with gradio.Row():
        with gradio.Column(scale=10, min_width=100):
            api_key_text = gradio.Textbox(label="OpenAI-APIkey", placeholder="sk-...", elem_id="api-key-textbox",  value='')
        # with gradio.Column(scale=1, min_width=100):
        #     api_key_load_btn = gradio.Button("🔄 从浏览器本地存储加载")
        #     api_key_load_btn.click(
        #         None,
        #         inputs=[],
        #         outputs=[api_key_text, token_text],
        #         _js=api_key__get_from_browser,
        #     )
        # with gradio.Column(scale=1, min_width=100):
        #     api_key_save_btn = gradio.Button("💾 保存到浏览器本地存储")
        #     api_key_save_btn.click(
        #         None,
        #         inputs=[api_key_text, token_text],
        #         outputs=[api_key_text, token_text],
        #         _js=api_key__save_to_browser,
        #     )
        pass
    pass

    # 开始执行按钮
    start_btn = gradio.Button(value='开始!')

    with gradio.Accordion(label="数据记录", elem_id='chat-markdown-wrap-box'):
        # 输出区域(隐藏状态)
        history = gradio.State(value=[])
        # 输出区域(md渲染)
        history_md_stable = gradio.Markdown(value="用户")
        history_md_stream = gradio.Markdown(value="助手")

    with gradio.Accordion("状态"):
        tips = gradio.Markdown(value="待命")

    # 中止执行按钮
    stop_btn = gradio.Button(value='中止!')

    with gradio.Accordion("下载数据", open=True):
        # gradio.Markdown("(暂时无法下载,可能是 Hugging Face 的限制,之后更新)")
        make_file_btn = gradio.Button(value='生成文件')
        with gradio.Row(visible=False) as file_row:
            # 下载区域(json文件)
            history_file_json = gradio.File(label='Json 下载', interactive=False)
            # 下载区域(md文件)
            history_file_md = gradio.File(label='Markdown 下载', interactive=False)
        pass
    pass


    make_file_btn.click(
        fn=make_history_file_fn,
        inputs=[history],
        outputs=[history_file_json, history_file_md, file_row],
    )


    start_event = start_btn.click(
        fn=sequential_chat_fn,
        inputs=[
            history,

            system_prompt_enabled,
            system_prompt,
            user_message_template,
            user_message_template_mask,
            user_message_template_mask_is_regex,
            user_message_list_text,
            user_message_list_text_splitter,
            user_message_list_text_splitter_is_regex,
            history_prompt_num,

            api_key_text, token_text,

            sleep_base,
            sleep_rand,
            prop_stream,
            prop_model,
            prop_temperature,
            prop_top_p,
            prop_choices_num,
            prop_max_tokens,
            prop_presence_penalty,
            prop_frequency_penalty,
            prop_logit_bias,
        ],
        outputs=[
            history,
            history_md_stable,
            history_md_stream,
            tips,
            file_row,
        ],
    )
    stop_btn.click(
        fn=None,
        inputs=[],
        outputs=[],
        cancels=[start_event],
    )


if __name__ == "__main__":
    demo.queue(concurrency_count=200).launch()