Spaces: Runtime error
| #!/usr/bin/env python | |
| # -*- encoding: utf-8 -*- | |
| ''' | |
| @Time : 2023/09/22 17:43:35 | |
| @Author : zoeyxiong | |
| @File : chatgpt_bot.py | |
| @Desc : 调用chatGPT类 | |
| ''' | |
| import os | |
| import json | |
| import openai | |
| import gradio as gr | |
default_model = "gpt-3.5-turbo"


class ChatGPT:
    """Stateful multi-turn chat client for the OpenAI ChatCompletion API.

    The whole conversation (system prompt, user questions and assistant
    answers) is kept in ``self.messages`` and sent with every request, so
    the model sees the full history on each turn.
    """

    def __init__(self, model=default_model, init_system=None, save_message=False):
        """Create a client with a fresh history containing only the system prompt.

        Args:
            model: model name passed to ``openai.ChatCompletion.create``;
                defaults to ``default_model``.
            init_system: system message dict (``{"role": "system", "content": ...}``).
                ``None`` selects the built-in default assistant prompt.
            save_message: opt-in flag for persisting user messages.
        """
        if init_system is None:
            # Build a new dict per instance; a dict literal used as the
            # default argument would be shared by every instance.
            init_system = {"role": "system", "content": "你是一个AI助手"}
        self.init_system = init_system
        self.model = model
        self.messages = [init_system]
        # If this is enabled, users must be informed.
        # NOTE(review): save_message and filename are stored but not read by
        # any code visible here; persistence appears unimplemented.
        self.save_message = save_message
        self.filename = "./user_messages.json"

    def ask_gpt(self):
        """Send the current history to the model and return the reply text.

        Does not modify ``self.messages``.
        """
        rsp = openai.ChatCompletion.create(
            model=self.model,
            messages=self.messages,
        )
        return rsp.get("choices")[0]["message"]["content"]

    def get_response(self, question):
        """Ask ``question`` in the context of the conversation so far.

        On success, both the question and the answer are appended to the
        history (multi-turn context) and the answer is returned. If the API
        call raises, the question is removed again so the history never ends
        with an unanswered user turn, and the exception is re-raised.
        """
        self.messages.append({"role": "user", "content": question})
        try:
            answer = self.ask_gpt()
        except Exception:
            # Roll back the user turn added above, then propagate the error.
            self.messages.pop()
            raise
        self.messages.append({"role": "assistant", "content": answer})
        return answer

    def clean_history(self):
        """Reset the history so it contains only the system prompt."""
        self.messages.clear()
        self.messages.append(self.init_system)
# Read the API key from the environment instead of hard-coding a secret in
# source. NOTE(review): the key previously committed here is public and
# should be revoked.
openai.api_key = os.environ.get("OPENAI_API_KEY")
MODEL_NAME = 'gpt-3.5-turbo'
# Custom system prompt
INIT_MSG = {"role": "system", "content": "你是一个资深算法工程师."}
# Port to serve on; default 7560, change it if that port is already in use
SERVER_PORT = 7560
# Shared bot instance that calls GPT for every request
chatgpt = ChatGPT(MODEL_NAME, INIT_MSG)
def predict(input, chatbot):
    """Answer ``input`` with ChatGPT and append the exchange to ``chatbot``.

    Args:
        input: the user's question from the textbox. (The name shadows the
            builtin but is kept for callers that pass it by keyword.)
        chatbot: list of ``(question, answer)`` pairs displayed by the Chatbot.

    Returns:
        The same ``chatbot`` list with one ``(input, response)`` pair appended.
        Blank or whitespace-only input is ignored: no API call is made and the
        list is returned unchanged.
    """
    if not input or not input.strip():
        return chatbot
    # Ask first and append afterwards, so a failed call leaves no half-filled row.
    response = chatgpt.get_response(input)
    chatbot.append((input, response))
    return chatbot
def reset_user_input():
    """Return a Gradio update that empties the input textbox."""
    cleared_text = ''
    return gr.update(value=cleared_text)
def reset_state():
    """Drop the bot's conversation history and return an empty chat log."""
    chatgpt.clean_history()
    empty_log = []
    return empty_log
def main():
    """Build the Gradio chat UI and serve it on ``SERVER_PORT``."""
    title_html = f"""<h1 align="center">{MODEL_NAME}</h1>"""
    with gr.Blocks() as demo:
        gr.HTML(title_html)
        # Conversation display
        chat_display = gr.Chatbot()
        with gr.Row():
            with gr.Column(scale=4):
                with gr.Column(scale=50):
                    # NOTE(review): Component.style() was removed in Gradio 4.0;
                    # on Gradio 4+ pass container=False to the constructor
                    # instead. Confirm the pinned gradio version.
                    question_box = gr.Textbox(
                        show_label=False, placeholder="Input...", lines=10
                    ).style(container=False)
                with gr.Column(min_width=32, scale=1):
                    send_button = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                clear_button = gr.Button("Clear History")

        # One click fires two listeners: fetch the answer, and clear the textbox.
        send_button.click(
            predict,
            [question_box, chat_display],
            [chat_display],
            show_progress=True,
        )
        send_button.click(reset_user_input, [], [question_box])
        # Wipe the conversation history
        clear_button.click(reset_state, outputs=[chat_display], show_progress=True)

    demo.queue().launch(share=False, inbrowser=True, server_port=SERVER_PORT)
# Start the web UI only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()