Spaces:
Running
Running
add files
Browse files- AESCipher.py +54 -0
- MongdbClient.py +95 -0
- OpenaiBot.py +60 -0
- README.md +13 -13
- app.py +76 -0
- lib/AESCipher.py +54 -0
- lib/MongdbClient.py +95 -0
- lib/OpenaiBot.py +60 -0
- offline/insert_user.py +24 -0
- server.py +126 -0
AESCipher.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: AESCipher.py
|
6 |
+
@Time: 2023-03-05 22:55
|
7 |
+
@Last_update: 2023-03-05 22:55
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
from Crypto.Cipher import AES
|
18 |
+
import base64
|
19 |
+
|
20 |
+
# 加密函数
|
21 |
+
def aes_encrypt(key, data):
|
22 |
+
# 将key转换成16、24、32位的字符串,不足的以空格补齐
|
23 |
+
key = key.ljust(32, ' ')
|
24 |
+
# 将data转换成16的倍数,不足的以空格补齐
|
25 |
+
data = data.ljust(16 * (len(data) // 16 + 1), ' ')
|
26 |
+
# 进行加密
|
27 |
+
cipher = AES.new(key.encode('utf-8'), AES.MODE_ECB)
|
28 |
+
encrypted_data = cipher.encrypt(data.encode('utf-8'))
|
29 |
+
# 将加密后的数据进行base64编码
|
30 |
+
encrypted_data = base64.b64encode(encrypted_data).decode('utf-8')
|
31 |
+
return encrypted_data
|
32 |
+
|
33 |
+
# 解密函数
|
34 |
+
def aes_decrypt(key, encrypted_data):
|
35 |
+
# 将key转换成16、24、32位的字符串,不足的以空格补齐
|
36 |
+
key = key.ljust(32, ' ')
|
37 |
+
# 对加密后的数据进行base64解码
|
38 |
+
encrypted_data = base64.b64decode(encrypted_data)
|
39 |
+
# 进行解密
|
40 |
+
cipher = AES.new(key.encode('utf-8'), AES.MODE_ECB)
|
41 |
+
decrypted_data = cipher.decrypt(encrypted_data).decode('utf-8')
|
42 |
+
# 去除解密后的数据中的空格
|
43 |
+
decrypted_data = decrypted_data.strip()
|
44 |
+
return decrypted_data
|
45 |
+
|
46 |
+
|
47 |
+
# 测试
|
48 |
+
if __name__ == '__main__':
|
49 |
+
key = '1234567890123456345345'
|
50 |
+
data = 'Hello, world!'
|
51 |
+
encrypted_data = aes_encrypt(key, data)
|
52 |
+
print('加密后的数据:', encrypted_data)
|
53 |
+
decrypted_data = aes_decrypt(key, encrypted_data)
|
54 |
+
print('解密后的数据:', decrypted_data)
|
MongdbClient.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: MongdbClient.py
|
6 |
+
@Time: 2023-03-03 20:25
|
7 |
+
@Last_update: 2023-03-03 20:25
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
import pymongo
|
18 |
+
from hashlib import sha256
|
19 |
+
|
20 |
+
|
21 |
+
class MongodbClient(object):
|
22 |
+
"""Mongodb客户端"""
|
23 |
+
def __init__(self):
|
24 |
+
self.myclient = pymongo.MongoClient("mongodb://localhost:27017/")
|
25 |
+
self.mydb = self.myclient["openai_bot"]
|
26 |
+
self.user_info = self.mydb['user_info']
|
27 |
+
self.user_history = self.mydb['user_history']
|
28 |
+
|
29 |
+
def insert_user(self, username, password):
|
30 |
+
"""离线添加用户"""
|
31 |
+
username = sha256(username.encode('utf8')).hexdigest()
|
32 |
+
password = sha256(password.encode('utf8')).hexdigest()
|
33 |
+
mydict = {'username': username, 'password': password}
|
34 |
+
_ = self.user_info.insert_one(mydict)
|
35 |
+
|
36 |
+
def check_user_exist(self, username, password):
|
37 |
+
"""检测用户是否存在"""
|
38 |
+
username = sha256(username.encode('utf8')).hexdigest()
|
39 |
+
password = sha256(password.encode('utf8')).hexdigest()
|
40 |
+
mydoc = self.user_info.find({'username': username, 'password': password}).limit(1)
|
41 |
+
res = [x for x in mydoc]
|
42 |
+
|
43 |
+
return len(res) >= 1
|
44 |
+
|
45 |
+
def update_user_access_token(self, username, access_token):
|
46 |
+
"""更新数据库的access_token以便后续使用"""
|
47 |
+
username = sha256(username.encode('utf8')).hexdigest()
|
48 |
+
# 先看是否有这个用户
|
49 |
+
mydoc = self.user_history.find({'username': username}).limit(1)
|
50 |
+
res = [x for x in mydoc]
|
51 |
+
# 如果没有则直接创建
|
52 |
+
if len(res) < 1:
|
53 |
+
mydict = {
|
54 |
+
'username': username, 'access_token': access_token,
|
55 |
+
'role': '你是ChatGPT,OpenAI训练的大规模语言模型,简明的回答用户的问题。', 'history': []}
|
56 |
+
_ = self.user_history.insert_one(mydict)
|
57 |
+
# 如果有则更新
|
58 |
+
else:
|
59 |
+
self.user_history.update_one({'username': username}, {'$set': {'access_token': access_token}})
|
60 |
+
|
61 |
+
def get_user_chat_history(self, access_token):
|
62 |
+
"""获取用户的聊天历史"""
|
63 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
64 |
+
res = [x for x in mydoc]
|
65 |
+
history_str, history_list = '', []
|
66 |
+
role = '你是ChatGPT,OpenAI训练的大规模语言模型,简明的回答用户的问题。'
|
67 |
+
if len(res) >= 1:
|
68 |
+
# 遍历加到history中
|
69 |
+
history_list = res[0]['history']
|
70 |
+
role = res[0]['role']
|
71 |
+
for qus, ans in history_list[::-1]:
|
72 |
+
history_str += f'Q: {qus}\nA: {ans}\n'
|
73 |
+
|
74 |
+
return history_str, history_list, role
|
75 |
+
|
76 |
+
def update_user_chat_history(self, access_token, qus, ans):
|
77 |
+
"""更新用户的聊天历史"""
|
78 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
79 |
+
res = [x for x in mydoc]
|
80 |
+
if len(res) >= 1:
|
81 |
+
self.user_history.update_one({'access_token': access_token}, {'$push': {'history': (qus, ans)}})
|
82 |
+
|
83 |
+
def delete_user_chat_history(self, access_token):
|
84 |
+
"""删除用户的聊天历史"""
|
85 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
86 |
+
res = [x for x in mydoc]
|
87 |
+
if len(res) >= 1:
|
88 |
+
self.user_history.update_one({'access_token': access_token}, {'$set': {'history': []}})
|
89 |
+
|
90 |
+
def update_role(self, access_token, role):
|
91 |
+
"""更新用户的聊天历史"""
|
92 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
93 |
+
res = [x for x in mydoc]
|
94 |
+
if len(res) >= 1:
|
95 |
+
self.user_history.update_one({'access_token': access_token}, {'$set': {'role': role}})
|
OpenaiBot.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: OpenaiBot.py
|
6 |
+
@Time: 2023-03-03 17:47
|
7 |
+
@Last_update: 2023-03-03 17:47
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
import os
|
18 |
+
import openai
|
19 |
+
|
20 |
+
|
21 |
+
class OpenaiBot(object):
|
22 |
+
"""调用openai的机器人"""
|
23 |
+
def __init__(self, temperature=0.5):
|
24 |
+
openai.api_key = os.environ.get('OPENAI_API_KEY')
|
25 |
+
self.model_engine = "gpt-3.5-turbo"
|
26 |
+
self.temperature = temperature
|
27 |
+
|
28 |
+
def set_api_key(self, api_key):
|
29 |
+
"""设定api key"""
|
30 |
+
openai.api_key = api_key
|
31 |
+
|
32 |
+
def construct_message(self, role, new_msg, history_list, keep_history=3):
|
33 |
+
"""
|
34 |
+
构造message,这里history_list是一个list,每个元素是一个tuple
|
35 |
+
"""
|
36 |
+
msg_list = [{"role": "system", "content": role}]
|
37 |
+
history_list = history_list[-keep_history:]
|
38 |
+
for user, assistant in history_list:
|
39 |
+
msg_list.append({"role": "user", "content": user})
|
40 |
+
msg_list.append({"role": "assistant", "content": assistant})
|
41 |
+
msg_list.append({"role": "user", "content": new_msg})
|
42 |
+
|
43 |
+
return msg_list
|
44 |
+
|
45 |
+
def get_response(self, role, new_msg, history_list, keep_history=3):
|
46 |
+
"""
|
47 |
+
通过openai获取回复
|
48 |
+
"""
|
49 |
+
msg_list = self.construct_message(role, new_msg, history_list, keep_history)
|
50 |
+
response = openai.ChatCompletion.create(
|
51 |
+
model=self.model_engine, messages=msg_list,
|
52 |
+
temperature=self.temperature
|
53 |
+
)
|
54 |
+
content = response['choices'][0]['message']['content']
|
55 |
+
|
56 |
+
return content
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
openai_bot = OpenaiBot()
|
README.md
CHANGED
@@ -1,13 +1,13 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
1 |
+
# ChatGPT-gradio
|
2 |
+
# 简介
|
3 |
+
|
4 |
+
> 这是一个可以简单的把ChatGPT的API应用的前端网页,通过gradio进行构建。同时给出了简单的无需数据库的版本和加入数据库的两个不同的版本。
|
5 |
+
|
6 |
+
|
7 |
+
基于ChatGPT的[API](https://github.com/openai/openai-python) 接口进行调用。
|
8 |
+
|
9 |
+
app的版本是已经直接部署到huggingface space的版本,没有任何的状态存储所以不需要数据库的支持。
|
10 |
+
|
11 |
+
而server版本是使用gradio结合mongodb的实现方式,加入了对于gradio的access token的识别bing获取,对于想要使用gradio构建自己的应用的朋友有一定的参考价值。需要注意的是这里需要通过offline部分的代码提前加入用户。
|
12 |
+
|
13 |
+
有任何问题欢迎来骚扰,vx: freshield
|
app.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: server_simple.py
|
6 |
+
@Time: 2023-03-09 22:39
|
7 |
+
@Last_update: 2023-03-09 22:39
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
import gradio as gr
|
18 |
+
from lib.OpenaiBot import OpenaiBot
|
19 |
+
|
20 |
+
openai_bot = OpenaiBot()
|
21 |
+
|
22 |
+
|
23 |
+
def ask_chatGPT(openai_key, role, new_msg, state, request: gr.Request):
|
24 |
+
"""向chatGPT提问"""
|
25 |
+
res_content = '对不起,服务器出错了,请稍后再试。'
|
26 |
+
res = [(new_msg, res_content)]
|
27 |
+
try:
|
28 |
+
openai_bot.set_api_key(openai_key)
|
29 |
+
res_content = openai_bot.get_response(role, new_msg, state)
|
30 |
+
res = [(new_msg, res_content)]
|
31 |
+
except Exception as e:
|
32 |
+
print(e)
|
33 |
+
finally:
|
34 |
+
state += res
|
35 |
+
res = state
|
36 |
+
|
37 |
+
return res, state
|
38 |
+
|
39 |
+
|
40 |
+
def clean_question(question):
|
41 |
+
"""清除问题"""
|
42 |
+
return ''
|
43 |
+
|
44 |
+
|
45 |
+
if __name__ == '__main__':
|
46 |
+
with gr.Blocks(title="尝试chatGPT对话", css="#maxheight {max-height: 390px} ") as demo:
|
47 |
+
state = gr.State([])
|
48 |
+
with gr.Column(variant='panel'):
|
49 |
+
# title
|
50 |
+
with gr.Row():
|
51 |
+
gr.Markdown("## 尝试chatGPT对话")
|
52 |
+
with gr.Row():
|
53 |
+
# left part
|
54 |
+
with gr.Column():
|
55 |
+
openai_key = gr.Textbox(
|
56 |
+
label='openai_key', placeholder='输入你openai的api key', type='password')
|
57 |
+
role_b = gr.Textbox(
|
58 |
+
label='请输入你设定的chatGPT的角色', lines=2,
|
59 |
+
value='你是ChatGPT,OpenAI训练的大规模语言模型,简明的回答用户的问题。')
|
60 |
+
question_b = gr.Textbox(
|
61 |
+
label='请输入你想要问的问题',
|
62 |
+
placeholder='输入你想提问的内容...',
|
63 |
+
lines=3
|
64 |
+
)
|
65 |
+
with gr.Row():
|
66 |
+
greet_btn = gr.Button('提交', variant="primary")
|
67 |
+
# right part
|
68 |
+
with gr.Column():
|
69 |
+
answer_b = gr.Chatbot(
|
70 |
+
label='chatGPT的问答', value=[(None, '请在这里提问')], elem_id='maxheight')
|
71 |
+
|
72 |
+
greet_btn.click(fn=ask_chatGPT, inputs=[openai_key, role_b, question_b, state], outputs=[answer_b, state])
|
73 |
+
greet_btn.click(fn=clean_question, inputs=[question_b], outputs=[question_b])
|
74 |
+
|
75 |
+
demo.launch(server_name='0.0.0.0', server_port=8080)
|
76 |
+
demo.close()
|
lib/AESCipher.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: AESCipher.py
|
6 |
+
@Time: 2023-03-05 22:55
|
7 |
+
@Last_update: 2023-03-05 22:55
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
from Crypto.Cipher import AES
|
18 |
+
import base64
|
19 |
+
|
20 |
+
# 加密函数
|
21 |
+
def aes_encrypt(key, data):
|
22 |
+
# 将key转换成16、24、32位的字符串,不足的以空格补齐
|
23 |
+
key = key.ljust(32, ' ')
|
24 |
+
# 将data转换成16的倍数,不足的以空格补齐
|
25 |
+
data = data.ljust(16 * (len(data) // 16 + 1), ' ')
|
26 |
+
# 进行加密
|
27 |
+
cipher = AES.new(key.encode('utf-8'), AES.MODE_ECB)
|
28 |
+
encrypted_data = cipher.encrypt(data.encode('utf-8'))
|
29 |
+
# 将加密后的数据进行base64编码
|
30 |
+
encrypted_data = base64.b64encode(encrypted_data).decode('utf-8')
|
31 |
+
return encrypted_data
|
32 |
+
|
33 |
+
# 解密函数
|
34 |
+
def aes_decrypt(key, encrypted_data):
|
35 |
+
# 将key转换成16、24、32位的字符串,不足的以空格补齐
|
36 |
+
key = key.ljust(32, ' ')
|
37 |
+
# 对加密后的数据进行base64解码
|
38 |
+
encrypted_data = base64.b64decode(encrypted_data)
|
39 |
+
# 进行解密
|
40 |
+
cipher = AES.new(key.encode('utf-8'), AES.MODE_ECB)
|
41 |
+
decrypted_data = cipher.decrypt(encrypted_data).decode('utf-8')
|
42 |
+
# 去除解密后的数据中的空格
|
43 |
+
decrypted_data = decrypted_data.strip()
|
44 |
+
return decrypted_data
|
45 |
+
|
46 |
+
|
47 |
+
# 测试
|
48 |
+
if __name__ == '__main__':
|
49 |
+
key = '1234567890123456345345'
|
50 |
+
data = 'Hello, world!'
|
51 |
+
encrypted_data = aes_encrypt(key, data)
|
52 |
+
print('加密后的数据:', encrypted_data)
|
53 |
+
decrypted_data = aes_decrypt(key, encrypted_data)
|
54 |
+
print('解密后的数据:', decrypted_data)
|
lib/MongdbClient.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: MongdbClient.py
|
6 |
+
@Time: 2023-03-03 20:25
|
7 |
+
@Last_update: 2023-03-03 20:25
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
import pymongo
|
18 |
+
from hashlib import sha256
|
19 |
+
|
20 |
+
|
21 |
+
class MongodbClient(object):
|
22 |
+
"""Mongodb客户端"""
|
23 |
+
def __init__(self):
|
24 |
+
self.myclient = pymongo.MongoClient("mongodb://localhost:27017/")
|
25 |
+
self.mydb = self.myclient["openai_bot"]
|
26 |
+
self.user_info = self.mydb['user_info']
|
27 |
+
self.user_history = self.mydb['user_history']
|
28 |
+
|
29 |
+
def insert_user(self, username, password):
|
30 |
+
"""离线添加用户"""
|
31 |
+
username = sha256(username.encode('utf8')).hexdigest()
|
32 |
+
password = sha256(password.encode('utf8')).hexdigest()
|
33 |
+
mydict = {'username': username, 'password': password}
|
34 |
+
_ = self.user_info.insert_one(mydict)
|
35 |
+
|
36 |
+
def check_user_exist(self, username, password):
|
37 |
+
"""检测用户是否存在"""
|
38 |
+
username = sha256(username.encode('utf8')).hexdigest()
|
39 |
+
password = sha256(password.encode('utf8')).hexdigest()
|
40 |
+
mydoc = self.user_info.find({'username': username, 'password': password}).limit(1)
|
41 |
+
res = [x for x in mydoc]
|
42 |
+
|
43 |
+
return len(res) >= 1
|
44 |
+
|
45 |
+
def update_user_access_token(self, username, access_token):
|
46 |
+
"""更新数据库的access_token以便后续使用"""
|
47 |
+
username = sha256(username.encode('utf8')).hexdigest()
|
48 |
+
# 先看是否有这个用户
|
49 |
+
mydoc = self.user_history.find({'username': username}).limit(1)
|
50 |
+
res = [x for x in mydoc]
|
51 |
+
# 如果没有则直接创建
|
52 |
+
if len(res) < 1:
|
53 |
+
mydict = {
|
54 |
+
'username': username, 'access_token': access_token,
|
55 |
+
'role': '你是ChatGPT,OpenAI训练的大规模语言模型,简明的回答用户的问题。', 'history': []}
|
56 |
+
_ = self.user_history.insert_one(mydict)
|
57 |
+
# 如果有则更新
|
58 |
+
else:
|
59 |
+
self.user_history.update_one({'username': username}, {'$set': {'access_token': access_token}})
|
60 |
+
|
61 |
+
def get_user_chat_history(self, access_token):
|
62 |
+
"""获取用户的聊天历史"""
|
63 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
64 |
+
res = [x for x in mydoc]
|
65 |
+
history_str, history_list = '', []
|
66 |
+
role = '你是ChatGPT,OpenAI训练的大规模语言模型,简明的回答用户的问题。'
|
67 |
+
if len(res) >= 1:
|
68 |
+
# 遍历加到history中
|
69 |
+
history_list = res[0]['history']
|
70 |
+
role = res[0]['role']
|
71 |
+
for qus, ans in history_list[::-1]:
|
72 |
+
history_str += f'Q: {qus}\nA: {ans}\n'
|
73 |
+
|
74 |
+
return history_str, history_list, role
|
75 |
+
|
76 |
+
def update_user_chat_history(self, access_token, qus, ans):
|
77 |
+
"""更新用户的聊天历史"""
|
78 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
79 |
+
res = [x for x in mydoc]
|
80 |
+
if len(res) >= 1:
|
81 |
+
self.user_history.update_one({'access_token': access_token}, {'$push': {'history': (qus, ans)}})
|
82 |
+
|
83 |
+
def delete_user_chat_history(self, access_token):
|
84 |
+
"""删除用户的聊天历史"""
|
85 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
86 |
+
res = [x for x in mydoc]
|
87 |
+
if len(res) >= 1:
|
88 |
+
self.user_history.update_one({'access_token': access_token}, {'$set': {'history': []}})
|
89 |
+
|
90 |
+
def update_role(self, access_token, role):
|
91 |
+
"""更新用户的聊天历史"""
|
92 |
+
mydoc = self.user_history.find({'access_token': access_token}).limit(1)
|
93 |
+
res = [x for x in mydoc]
|
94 |
+
if len(res) >= 1:
|
95 |
+
self.user_history.update_one({'access_token': access_token}, {'$set': {'role': role}})
|
lib/OpenaiBot.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: OpenaiBot.py
|
6 |
+
@Time: 2023-03-03 17:47
|
7 |
+
@Last_update: 2023-03-03 17:47
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
import os
|
18 |
+
import openai
|
19 |
+
|
20 |
+
|
21 |
+
class OpenaiBot(object):
|
22 |
+
"""调用openai的机器人"""
|
23 |
+
def __init__(self, temperature=0.5):
|
24 |
+
openai.api_key = os.environ.get('OPENAI_API_KEY')
|
25 |
+
self.model_engine = "gpt-3.5-turbo"
|
26 |
+
self.temperature = temperature
|
27 |
+
|
28 |
+
def set_api_key(self, api_key):
|
29 |
+
"""设定api key"""
|
30 |
+
openai.api_key = api_key
|
31 |
+
|
32 |
+
def construct_message(self, role, new_msg, history_list, keep_history=3):
|
33 |
+
"""
|
34 |
+
构造message,这里history_list是一个list,每个元素是一个tuple
|
35 |
+
"""
|
36 |
+
msg_list = [{"role": "system", "content": role}]
|
37 |
+
history_list = history_list[-keep_history:]
|
38 |
+
for user, assistant in history_list:
|
39 |
+
msg_list.append({"role": "user", "content": user})
|
40 |
+
msg_list.append({"role": "assistant", "content": assistant})
|
41 |
+
msg_list.append({"role": "user", "content": new_msg})
|
42 |
+
|
43 |
+
return msg_list
|
44 |
+
|
45 |
+
def get_response(self, role, new_msg, history_list, keep_history=3):
|
46 |
+
"""
|
47 |
+
通过openai获取回复
|
48 |
+
"""
|
49 |
+
msg_list = self.construct_message(role, new_msg, history_list, keep_history)
|
50 |
+
response = openai.ChatCompletion.create(
|
51 |
+
model=self.model_engine, messages=msg_list,
|
52 |
+
temperature=self.temperature
|
53 |
+
)
|
54 |
+
content = response['choices'][0]['message']['content']
|
55 |
+
|
56 |
+
return content
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
openai_bot = OpenaiBot()
|
offline/insert_user.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: insert_user.py
|
6 |
+
@Time: 2023-03-09 22:35
|
7 |
+
@Last_update: 2023-03-09 22:35
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
from lib.MongdbClient import MongodbClient
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == '__main__':
|
21 |
+
# 离线添加用户
|
22 |
+
mongo_client = MongodbClient()
|
23 |
+
username, password = '', ''
|
24 |
+
mongo_client.insert_user(username, password)
|
server.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
"""
|
3 |
+
@Author: Freshield
|
4 |
+
@Contact: yangyufresh@163.com
|
5 |
+
@File: b2_try_gradio.py
|
6 |
+
@Time: 2023-03-03 16:06
|
7 |
+
@Last_update: 2023-03-03 16:06
|
8 |
+
@Desc: None
|
9 |
+
@==============================================@
|
10 |
+
@ _____ _ _ _ _ @
|
11 |
+
@ | __|___ ___ ___| |_|_|___| |_| | @
|
12 |
+
@ | __| _| -_|_ -| | | -_| | . | @
|
13 |
+
@ |__| |_| |___|___|_|_|_|___|_|___| @
|
14 |
+
@ Freshield @
|
15 |
+
@==============================================@
|
16 |
+
"""
|
17 |
+
import gradio as gr
|
18 |
+
from lib.OpenaiBot import OpenaiBot
|
19 |
+
from lib.MongdbClient import MongodbClient
|
20 |
+
|
21 |
+
openai_bot = OpenaiBot()
|
22 |
+
mongo_client = MongodbClient()
|
23 |
+
|
24 |
+
|
25 |
+
def check_auth(username, passowrd):
|
26 |
+
return mongo_client.check_user_exist(username, passowrd)
|
27 |
+
|
28 |
+
|
29 |
+
def ask_chatGPT(role, new_msg, state, request: gr.Request):
|
30 |
+
"""向chatGPT提问"""
|
31 |
+
# 获取access_token
|
32 |
+
access_token = request.request.cookies['access-token-unsecure']
|
33 |
+
res_content = '对不起,服务器出错了,请稍后再试。'
|
34 |
+
res = [(new_msg, res_content)]
|
35 |
+
try:
|
36 |
+
res_content = openai_bot.get_response(role, new_msg, state)
|
37 |
+
res = [(new_msg, res_content)]
|
38 |
+
except Exception as e:
|
39 |
+
print(e)
|
40 |
+
finally:
|
41 |
+
state += res
|
42 |
+
res = state
|
43 |
+
|
44 |
+
# 更新history
|
45 |
+
mongo_client.update_user_chat_history(access_token, new_msg, res_content)
|
46 |
+
history, _, _ = mongo_client.get_user_chat_history(access_token)
|
47 |
+
|
48 |
+
return res, state, history
|
49 |
+
|
50 |
+
|
51 |
+
def clean_question(question):
|
52 |
+
"""清除问题"""
|
53 |
+
return ''
|
54 |
+
|
55 |
+
|
56 |
+
def clean_history(history, request: gr.Request):
|
57 |
+
"""清除历史记录"""
|
58 |
+
access_token = request.request.cookies['access-token-unsecure']
|
59 |
+
mongo_client.delete_user_chat_history(access_token)
|
60 |
+
history, _, _ = mongo_client.get_user_chat_history(access_token)
|
61 |
+
|
62 |
+
return history
|
63 |
+
|
64 |
+
|
65 |
+
def update_role(role, request: gr.Request):
|
66 |
+
"""更新角色"""
|
67 |
+
access_token = request.request.cookies['access-token-unsecure']
|
68 |
+
mongo_client.update_role(access_token, role)
|
69 |
+
|
70 |
+
return role
|
71 |
+
|
72 |
+
|
73 |
+
if __name__ == '__main__':
|
74 |
+
with gr.Blocks(title="尝试chatGPT对话", css="#maxheight {max-height: 390px} ") as demo:
|
75 |
+
state = gr.State([])
|
76 |
+
with gr.Column(variant='panel'):
|
77 |
+
# title
|
78 |
+
with gr.Row():
|
79 |
+
gr.Markdown("## 尝试chatGPT对话")
|
80 |
+
with gr.Row():
|
81 |
+
# left part
|
82 |
+
with gr.Column():
|
83 |
+
role_b = gr.Textbox(
|
84 |
+
label='请输入你设定的chatGPT的角色', lines=2,
|
85 |
+
value='你是ChatGPT,OpenAI训练的大规模语言模型,简明的回答用户的问题。')
|
86 |
+
question_b = gr.Textbox(
|
87 |
+
label='请输入你想要问的问题',
|
88 |
+
placeholder='输入你想提问的内容...',
|
89 |
+
lines=3
|
90 |
+
)
|
91 |
+
with gr.Row():
|
92 |
+
role_btn = gr.Button('更新角色')
|
93 |
+
greet_btn = gr.Button('提交', variant="primary")
|
94 |
+
with gr.Row():
|
95 |
+
clean_history_btn = gr.Button('清除历史记录')
|
96 |
+
# right part
|
97 |
+
with gr.Column():
|
98 |
+
answer_b = gr.Chatbot(
|
99 |
+
label='chatGPT的问答', value=[(None, '请在这里提问')], elem_id='maxheight')
|
100 |
+
with gr.Row():
|
101 |
+
history_b = gr.TextArea(
|
102 |
+
label='历史记录', interactive=False)
|
103 |
+
|
104 |
+
role_btn.click(fn=update_role, inputs=[role_b], outputs=[role_b])
|
105 |
+
greet_btn.click(fn=ask_chatGPT, inputs=[role_b, question_b, state], outputs=[answer_b, state, history_b])
|
106 |
+
greet_btn.click(fn=clean_question, inputs=[question_b], outputs=[question_b])
|
107 |
+
clean_history_btn.click(fn=clean_history, inputs=[history_b], outputs=[history_b])
|
108 |
+
|
109 |
+
def demo_load(request: gr.Request):
|
110 |
+
"""第一次进入demo时候运行的"""
|
111 |
+
# 更新用户的access_token
|
112 |
+
token_dict = demo.server_app.tokens
|
113 |
+
access_token = request.request.cookies['access-token-unsecure']
|
114 |
+
username = token_dict[access_token]
|
115 |
+
mongo_client.update_user_access_token(username, access_token)
|
116 |
+
|
117 |
+
# 获取用户的历史记录
|
118 |
+
history_str, history_list, role = mongo_client.get_user_chat_history(access_token)
|
119 |
+
|
120 |
+
return history_str, history_list[-10:], role, history_list[-10:]
|
121 |
+
|
122 |
+
demo.load(demo_load, None, [history_b, state, role_b, answer_b])
|
123 |
+
|
124 |
+
demo.launch(
|
125 |
+
auth=check_auth, auth_message='请输入给定的用户名和密码', server_name='0.0.0.0', server_port=8081)
|
126 |
+
demo.close()
|