art9's picture
improve interface ui
6c32cc1
raw
history blame
6.88 kB
import streamlit as st
import requests
import openai
import logging
# 设置OpenAI API密钥
openai.api_key = st.secrets["OPENAI_API_KEY"]
WELM_SECRET = st.secrets["WELM_SECRET"]
# 准备一些 prompt 的例子
def examples():
st.write('''<style>
[data-testid="column"] {
min-width: 1rem !important;
}
</style>''', unsafe_allow_html=True)
columns = st.columns(4)
with columns[0]:
if st.button('个性对话生成'):
st.session_state['prompt'] = "李白,字太白,号⻘莲居士,又号“谪仙人”,唐代伟大的浪漫主义 诗人,被后人誉为“诗仙”。\n我:今天我们穿越时空连线李白,请问李白你爱喝酒吗? 李白:当然。花间一壶酒,独酌无相亲。举杯邀明月,对影成三人。 \n我:你觉得杜甫怎么样? \n李白:他很仰慕我,但他有属于自己的⻛采。 \n我:你为何能如此逍遥? \n李白:天生我材必有用,千金散尽还复来!\n我:你去过哪些地方?\n李白:"
with columns[1]:
if st.button('开放问题回答'):
st.session_state['prompt'] = '请根据所学知识回答下面这个问题\n问题:百年孤独的作者是?\n回答:加西亚·马尔克斯\n问题:二战转折点是?\n回答:'
with columns[2]:
if st.button('文本风格转换'):
st.session_state['prompt'] = "有这样一段文本,{医生微笑着递给小明棒棒糖,同时让小明服下了药。}\n改写这段话让它变得更加惊悚。{医生眼露凶光让小明服药,小明感到非常害怕}。\n\n有这样一段文本,{雨下得很大}\n改写这段话让它变得更加具体。{一霎时,雨点连成了线,大雨就像天塌了似的铺天盖地从空中倾泻下来。}。\n\n有这样一段文本,{王老师离开了电影院,外面已经天黑了}\n改写这段话让它包含更多电影信息。{这部电影比小王预想的时间要长,虽然口碑很好,但离开电影院时,小王还是有些失望。}\n\n有这样一段文本,{男人站在超市外面打电话}\n改写这段话来描述小丑。{男人站在马戏团外一边拿着气球一边打电话}\n\n有这样一段文本,{风铃声响起}\n改写这段话写的更加丰富。{我对这个风铃的感情是由它的铃声引起的。每当风吹来时,风铃发出非常动听的声音,听起来是那么乐观、豁达,像一个小女孩格格的笑声。}\n\n有这样一段文本,{我想家了}\n改写这段话包含更多悲伤的感情。{"
with columns[3]:
if st.button('文本续写'):
st.session_state['prompt'] = "中国地大物博,自然⻛光秀丽,大自然的⻤斧神工造就了许多动人心魄的美景,"
# 定义completion函数
def completion(model_engine, prompt, max_tokens, temperature, top_p, top_k, n, stop_tokens):
if model_engine == "davinci-003":
model = "text-davinci-003"
answer = openai.Completion.create(
model=model, prompt=prompt, temperature=temperature, max_tokens=max_tokens, top_p=top_p, n=n,
stop=[" Human:", " AI:"], frequency_penalty=0, presence_penalty=0.6,
)
for idx, choice in enumerate(answer['choices']):
text = choice['text']
st.success(f'生成结果#{idx}: ')
st.write(text)
st.json(answer, expanded=False)
elif model_engine == "WeLM":
resp = requests.post("https://welm.weixin.qq.com/v1/completions", json={
'prompt': prompt,
'max_tokens': max_tokens,
'temperature': temperature,
'top_p': top_p,
'top_k': top_k,
'n': n,
'model': 'xl',
"stop": stop_tokens,
}, headers={"Authorization": f"Bearer {WELM_SECRET}"})
answer = resp.json()
for idx, choice in enumerate(answer['choices']):
if choice.get("finish_reason", None) != "finished":
st.error(f'生成结果#{idx}出错: {choice["finish_reason"]}')
elif choice.get("text", None) is None:
st.error(f'生成结果#{idx}出错: internal error')
else:
text = choice.get("text", "")
# text = cut_message(text)
if len(text) == 0:
st.info(f'生成结果#{idx}: 结果为空,可能的原因:生成的第一个字符为stop字符,请合理配置prompt或stop。比如,在prompt后追加"某某:"')
else:
st.success(f'生成结果#{idx}: ')
st.write(text)
st.json(answer, expanded=False)
# Streamlit应用程序
def app():
# 左侧栏
st.sidebar.title("参数设置")
model_engine_list = st.sidebar.multiselect("请选择要对比的模型", ["davinci-003", "WeLM"], default=["davinci-003", "WeLM"])
default_top_p = 0.95
default_top_k = 0
default_temperature = 0.85
default_n = 3
default_tokens = 256
temperature = st.sidebar.slider("Temperature", 0.0, 1.0, default_temperature, 0.01)
top_p = st.sidebar.slider('Top p', 0.0, 1.0, default_top_p)
top_k = st.sidebar.slider('Top k', 0, 100, default_top_k)
n = st.sidebar.slider('n', 1, 5, default_n)
max_tokens = st.sidebar.slider('max tokens', 4, 512, default_tokens)
stop_tokens = ""
if st.sidebar.checkbox("使用换行符作为截断", value=True):
stop_tokens = "\n"
# 主界面
st.title("对比不同模型生成文本的效果")
st.text('Tips: ')
st.text("* Davinci是是GPT-3语言生成模型,可以一定程度上理解用户的指令")
st.text("* WeLM不是一个直接的对话机器人,而是一个补全用户输入信息的生成模型")
st.text("* 因此 prompt 需要经过一定的设计,才能有比较好的效果")
st.text("* 修改Prompt可以更多参考 https://welm.weixin.qq.com/docs/introduction/ 或者使用下方的例子")
examples()
prompt = st.text_area("请输入Prompt:", key="prompt")
if st.button("生成"):
if prompt.strip() == '':
st.error("请输入内容")
st.stop()
columns = st.columns(len(model_engine_list))
for col, model_engine in zip(columns, model_engine_list):
with col:
st.subheader(model_engine)
with st.spinner("正在生成中..."):
completion(model_engine, prompt, max_tokens, temperature, top_p, top_k, n, stop_tokens)
pf = prompt.replace('\n','\\n')
logging.info(f"n={n},T={temperature},top_p={top_p},top_k={top_k},token={max_tokens},m={model_engine_list},p={pf}")
if __name__ == '__main__':
st.set_page_config(
page_title="对比不同模型生成文本的效果", layout="wide", initial_sidebar_state="auto",
)
app()