# Hugging Face Spaces status when this file was captured: Runtime error
#!/usr/bin/env python | |
# -*- encoding: utf-8 -*- | |
''' | |
@Time : 2023/09/22 17:43:35 | |
@Author : zoeyxiong | |
@File : chatgpt_bot.py | |
@Desc : 调用chatGPT类 | |
''' | |
import os | |
import json | |
import openai | |
import gradio as gr | |
# Default OpenAI chat model identifier.
# Not referenced by the code in this file.
default_model: str = "gpt-3.5-turbo"
class ChatGPT:
    """Multi-turn chat session backed by the OpenAI chat completion API.

    The whole conversation (system prompt, user questions and assistant
    answers) is kept in ``self.messages`` and sent with every request, so
    the model sees the full dialogue history.

    NOTE(review): ``openai.ChatCompletion`` is the pre-1.0 openai-python
    interface and was removed in openai>=1.0; pin ``openai<1`` or migrate.
    """

    def __init__(self, model, init_system=None, save_message=False):
        """Create a chat session.

        Args:
            model: Model name passed to ``openai.ChatCompletion.create``.
            init_system: System message dict that seeds the conversation.
                Defaults to ``{"role": "system", "content": "你是一个AI助手"}``
                ("You are an AI assistant").
            save_message: Flag stored on the instance; no method in this
                class reads it yet.
        """
        if init_system is None:
            # Build the default per instance rather than sharing one mutable
            # dict literal between every call (mutable-default pitfall).
            init_system = {"role": "system", "content": "你是一个AI助手"}
        self.init_system = init_system
        self.model = model
        self.messages = [init_system]
        # If this option is ever enabled, users must be informed.
        self.save_message = save_message
        # Intended path for saved messages; not used by any method here.
        self.filename = "./user_messages.json"

    def ask_gpt(self):
        """Send the current ``self.messages`` and return the reply text.

        Does not modify ``self.messages``.
        """
        rsp = openai.ChatCompletion.create(
            model=self.model,
            messages=self.messages,
        )
        return rsp.get("choices")[0]["message"]["content"]

    def get_response(self, question):
        """Ask ``question`` in the context of the conversation so far.

        The question and the model's answer are both appended to
        ``self.messages`` so later calls carry the dialogue history.
        If the API call raises, the question is removed again (so the
        history never ends in an unanswered user turn) and the exception
        propagates to the caller.

        Returns:
            The assistant's answer text.
        """
        self.messages.append({"role": "user", "content": question})
        try:
            answer = self.ask_gpt()
        except Exception:
            # Roll back the unanswered question, then let the caller handle it.
            self.messages.pop()
            raise
        # Record the answer so the next turn includes it as history.
        self.messages.append({"role": "assistant", "content": answer})
        return answer

    def clean_history(self):
        """Clear the conversation, keeping only the initial system message."""
        self.messages.clear()
        self.messages.append(self.init_system)
# Read the API key from the environment instead of hard-coding it: a key
# committed to source is exposed to anyone who can read the file.
# NOTE(review): the key that was previously hard-coded here should be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
MODEL_NAME = 'gpt-3.5-turbo'
# Custom system prompt ("You are a senior algorithm engineer.")
INIT_MSG = {"role": "system", "content": "你是一个资深算法工程师."}
# Server port: defaults to 7560; set GRADIO_SERVER_PORT to override it,
# e.g. when the default port is already in use.
SERVER_PORT = int(os.environ.get("GRADIO_SERVER_PORT", 7560))
# Chat session used by the UI callbacks.
# NOTE(review): this single module-level instance is shared by every
# connected user, so all visitors share one conversation history.
chatgpt = ChatGPT(MODEL_NAME, INIT_MSG)
def predict(input, chatbot):
    """Send the user's question to ChatGPT and add the exchange to the chat.

    Args:
        input: Text from the input box. (The name shadows the builtin
            ``input``; it is kept so the callback signature is unchanged.)
        chatbot: Chat history as a list of ``(question, answer)`` pairs;
            modified in place.

    Returns:
        The updated chat history. Empty or whitespace-only input is ignored
        and the history is returned unchanged, so no blank request is sent.
    """
    if not input or not input.strip():
        return chatbot
    # Ask first, then append, so a failed call leaves no half-filled row.
    response = chatgpt.get_response(input)
    chatbot.append((input, response))
    return chatbot
def reset_user_input():
    """Return a Gradio update that empties the input textbox."""
    cleared = gr.update(value='')
    return cleared
def reset_state():
    """Forget the conversation and return an empty history for the chatbot.

    The shared ChatGPT session is reset via ``clean_history``; the returned
    empty list clears the chat display.
    """
    chatgpt.clean_history()
    empty_history = []
    return empty_history
def main():
    """Build the Gradio chat UI and start the web server on SERVER_PORT."""
    with gr.Blocks() as demo:
        gr.HTML("""<h1 align="center">{}</h1>""".format(MODEL_NAME))
        # Chat history display
        chatbot = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=50):
                    # container=False goes to the constructor: the
                    # Component.style() method was removed in Gradio 4.
                    user_input = gr.Textbox(
                        show_label=False,
                        placeholder="Input...",
                        lines=10,
                        container=False,
                    )
                with gr.Column(min_width=32, scale=1):
                    submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                empty_btn = gr.Button("Clear History")
        # Submit: ask the model and clear the input box. show_progress is
        # left at its default (progress shown) instead of passing a bool,
        # which newer Gradio replaced with "full"/"minimal"/"hidden".
        submit_btn.click(predict, [user_input, chatbot], [chatbot])
        submit_btn.click(reset_user_input, [], [user_input])
        # Clear the conversation history
        empty_btn.click(reset_state, outputs=[chatbot])
    demo.queue().launch(share=False, inbrowser=True, server_port=SERVER_PORT)
if __name__ == '__main__':
    # Launch the web UI only when this file is run directly, not on import.
    main()