Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -8,11 +8,12 @@ st.text('Tips: ')
8
  st.text("* WeLM不是一个直接的对话机器人,而是一个补全用户输入信息的生成模型")
9
  st.text("* 修改Prompt可以更多参考 https://welm.weixin.qq.com/docs/introduction/")
10
  st.text("* 你的输入可能会被我们拼接在预设的prompt尾部后再发送给API")
11
- st.text("* 在每个任务的下方我们展示了该任务请求API时完整的参数(包含完整的prompt)")
12
 
13
 
14
 
15
  class Task(str, Enum):
 
16
  DIALOG_JOURNAL = "对话(Elon musk)"
17
  QA = "问答"
18
  COPY= "文案生成"
@@ -20,7 +21,7 @@ class Task(str, Enum):
20
  READING_COMPREHENSION = "阅读理解"
21
  TRANSLATE = "翻译"
22
  COMPLETION = "文章续写"
23
- FREE = "自由任务"
24
 
25
 
26
  task_value2type = {v.value: v.name for v in Task}
@@ -32,6 +33,17 @@ task_type = st.selectbox(
32
  task_type = task_value2type[task_type]
33
 
34
  task2prompt_pre = {
 
 
 
 
 
 
 
 
 
 
 
35
  Task.READING_COMPREHENSION: """请阅读文章后根据文章内容给出问题的答案。
36
  文章:中国空间技术研究院(China Academy of Space Technology,简称CAST)隶属于中国航天科技集团公司,是中国空间技术的主要研究中心和航天器研制、生产基地,成立于1968年2月20日。下设10个研究所和一个工厂。现任院长为杨保华,院党委书记为李开民。1970年4月24日,中国空间技术研究院成功研制并发射了中国首颗人造地球卫星东方红一号。2003年10月,神舟五号载人飞船载人飞行取得成功。2005年,神舟六号载人飞船实现多人多天的太空飞行。截至2005年,中国空间技术研究院研制并成功发射了68颗不同类型的人造卫星、4艘无人试验飞船和2艘载人飞船,涵盖通信广播卫星、返回式遥感>卫星、地球资源卫星、气象卫星、科学探测与技术试验卫星、导航定位卫星和载人航天器等领域。
37
  问题:中国空间技术研究院在哪年成立?
@@ -89,7 +101,6 @@ Elon Musk: 51岁,但我说我有40岁人的精力。在健康方面最重要
89
  """,
90
  Task.COMPLETION: """
91
  """,
92
- Task.FREE: ""
93
  }
94
 
95
  task2prompt_end = {
@@ -114,13 +125,15 @@ Elon Musk:""",
114
  prompt_fix = task2prompt_pre[Task[task_type]]
115
  prompt_user = task2prompt_end[Task[task_type]]
116
 
117
- user_input = st.text_area('你的输入(最终完整输入请见下方 API 请求内容)', value=prompt_user, height=180)
118
  all_input = prompt_fix + user_input
 
 
119
  all_input = all_input.rstrip('\\n')
120
 
121
 
122
  with st.expander("配置"):
123
- stop_tokens = ""
124
  def cut_message(answer: str):
125
  end = []
126
  for etk in stop_tokens:
@@ -137,14 +150,22 @@ with st.expander("配置"):
137
  default_top_p, default_temperature, default_n, default_tokens = 0.0, 0.0, 1, 60
138
  elif task_type == 'COMPLETION':
139
  default_top_p, default_temperature, default_n, default_tokens = 0.95, 0.85, 1, 150
 
 
140
  else:
141
  default_top_p, default_temperature, default_n, default_tokens = 0.95, 0.85, 3, 64
142
 
143
- model = st.selectbox("model", ["medium", "large", "xl"], index=2)
 
 
 
 
144
  top_p = st.slider('top p', 0.0, 1.0, default_top_p)
145
  top_k = st.slider('top k', 0, 100, 0)
146
  temperature = st.slider('temperature', 0.0, 1.0, default_temperature)
147
  n = st.slider('n', 1, 5, default_n)
 
 
148
  max_tokens = st.slider('max tokens', 4, 512, default_tokens)
149
 
150
  if st.checkbox("使用换行符作为截断", value=False):
@@ -183,19 +204,19 @@ def completion():
183
  st.error(f"生成结果出错:{str(e)}")
184
 
185
 
186
- code_str = """
187
- post_json = {{
188
- 'prompt': '{all_input}',
189
- 'model': '{model}',
190
- 'max_tokens': {max_tokens},
191
- 'temperature': {temperature},
192
- 'top_p': {top_p},
193
- 'top_k': {top_k},
194
- 'n': {n},
195
- "stop": '{stop_tokens}',
196
- }}
197
- """.format(all_input=all_input,model=model,max_tokens=max_tokens,temperature=temperature, top_p=top_p,top_k=top_k,n=n,stop_tokens=stop_tokens)
198
- st.code(code_str)
199
 
200
  if st.button('立即生成'):
201
  completion()
 
8
  st.text("* WeLM不是一个直接的对话机器人,而是一个补全用户输入信息的生成模型")
9
  st.text("* 修改Prompt可以更多参考 https://welm.weixin.qq.com/docs/introduction/")
10
  st.text("* 你的输入可能会被我们拼接在预设的prompt尾部后再发送给API")
11
+ # st.text("* 在每个任务的下方我们展示了该任务请求API时完整的参数(包含完整的prompt)")
12
 
13
 
14
 
15
  class Task(str, Enum):
16
+ FREE = "自由任务"
17
  DIALOG_JOURNAL = "对话(Elon musk)"
18
  QA = "问答"
19
  COPY= "文案生成"
 
21
  READING_COMPREHENSION = "阅读理解"
22
  TRANSLATE = "翻译"
23
  COMPLETION = "文章续写"
24
+
25
 
26
 
27
  task_value2type = {v.value: v.name for v in Task}
 
33
  task_type = task_value2type[task_type]
34
 
35
  task2prompt_pre = {
36
+ Task.FREE: """
37
+
38
+ Human: 你是谁<eot>
39
+
40
+ Assistant: 我是 Assistant,一个由 WeChat 训练的 AI 语言模型。
41
+ 我对人类活动有全面的了解,可以为你解答各种领域的问题。
42
+ 我会在最短的篇幅内给你信息量最丰富的回答。
43
+ 我的回答是理性、准确的。
44
+ 在我认为合适的时候,我可能会用英文回复你。<eot>
45
+
46
+ Human: """,
47
  Task.READING_COMPREHENSION: """请阅读文章后根据文章内容给出问题的答案。
48
  文章:中国空间技术研究院(China Academy of Space Technology,简称CAST)隶属于中国航天科技集团公司,是中国空间技术的主要研究中心和航天器研制、生产基地,成立于1968年2月20日。下设10个研究所和一个工厂。现任院长为杨保华,院党委书记为李开民。1970年4月24日,中国空间技术研究院成功研制并发射了中国首颗人造地球卫星东方红一号。2003年10月,神舟五号载人飞船载人飞行取得成功。2005年,神舟六号载人飞船实现多人多天的太空飞行。截至2005年,中国空间技术研究院研制并成功发射了68颗不同类型的人造卫星、4艘无人试验飞船和2艘载人飞船,涵盖通信广播卫星、返回式遥感>卫星、地球资源卫星、气象卫星、科学探测与技术试验卫星、导航定位卫星和载人航天器等领域。
49
  问题:中国空间技术研究院在哪年成立?
 
101
  """,
102
  Task.COMPLETION: """
103
  """,
 
104
  }
105
 
106
  task2prompt_end = {
 
125
  prompt_fix = task2prompt_pre[Task[task_type]]
126
  prompt_user = task2prompt_end[Task[task_type]]
127
 
128
+ user_input = st.text_area('你的输入', value=prompt_user, height=150)
129
  all_input = prompt_fix + user_input
130
+ if Task[task_type] == Task.FREE:
131
+ all_input = prompt_fix + user_input + "<eot>\n\nAssistant: "
132
  all_input = all_input.rstrip('\\n')
133
 
134
 
135
  with st.expander("配置"):
136
+ stop_tokens = ["<eot>", "\n\nHuman"]
137
  def cut_message(answer: str):
138
  end = []
139
  for etk in stop_tokens:
 
150
  default_top_p, default_temperature, default_n, default_tokens = 0.0, 0.0, 1, 60
151
  elif task_type == 'COMPLETION':
152
  default_top_p, default_temperature, default_n, default_tokens = 0.95, 0.85, 1, 150
153
+ elif task_type == 'FREE':
154
+ default_top_p, default_temperature, default_n, default_tokens = 0.95, 0.85, 1, 300
155
  else:
156
  default_top_p, default_temperature, default_n, default_tokens = 0.95, 0.85, 3, 64
157
 
158
+ if task_type == 'FREE':
159
+ model_list = ["xl-answer"]
160
+ else:
161
+ model_list = ["xl", "large", "medium"]
162
+ model = st.selectbox("model", model_list, index=0)
163
  top_p = st.slider('top p', 0.0, 1.0, default_top_p)
164
  top_k = st.slider('top k', 0, 100, 0)
165
  temperature = st.slider('temperature', 0.0, 1.0, default_temperature)
166
  n = st.slider('n', 1, 5, default_n)
167
+ if task_type == 'FREE':
168
+ n = 1
169
  max_tokens = st.slider('max tokens', 4, 512, default_tokens)
170
 
171
  if st.checkbox("使用换行符作为截断", value=False):
 
204
  st.error(f"生成结果出错:{str(e)}")
205
 
206
 
207
+ # code_str = """
208
+ # post_json = {{
209
+ # 'prompt': '{all_input}',
210
+ # 'model': '{model}',
211
+ # 'max_tokens': {max_tokens},
212
+ # 'temperature': {temperature},
213
+ # 'top_p': {top_p},
214
+ # 'top_k': {top_k},
215
+ # 'n': {n},
216
+ # "stop": '{stop_tokens}',
217
+ # }}
218
+ # """.format(all_input=all_input,model=model,max_tokens=max_tokens,temperature=temperature, top_p=top_p,top_k=top_k,n=n,stop_tokens=stop_tokens)
219
+ # st.code(code_str)
220
 
221
  if st.button('立即生成'):
222
  completion()