# april_aligner/deploy/aligner_inference_demo.py
# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""访问文本模型的命令行界面"""
import argparse
import os
import random

import gradio as gr
from openai import OpenAI

random.seed(42)
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
# System prompt (in Chinese: "You are a helpful AI assistant that can answer users' questions and provide help"); modify as needed
SYSTEM_PROMPT = "你是一个有帮助的AI助手,能够回答用户的问题并提供帮助。"
# Connection settings
openai_api_key = "jiayi"  # Not actually checked; only needed to construct the OpenAI clients
aligner_port = 8013
base_port = 8011
aligner_api_base = f"http://0.0.0.0:{aligner_port}/v1"
base_api_base = f"http://0.0.0.0:{base_port}/v1"
# NOTE: set these to the model names/paths actually served at the two endpoints above
aligner_model = ""
base_model = ""
aligner_client = OpenAI(
    api_key=openai_api_key,
    base_url=aligner_api_base,
)
base_client = OpenAI(
    api_key=openai_api_key,
    base_url=base_api_base,
)
# Example questions
TEXT_EXAMPLES = [
    "介绍一下北京大学的历史",
    "解释一下什么是深度学习",
    "写一首关于春天的诗",
]
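# (Roughly: "Introduce the history of Peking University", "Explain what deep learning is",
#  "Write a poem about spring")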
def text_conversation(text: str, role: str = 'user'):
    """Build a single-message conversation entry."""
    return [{'role': role, 'content': text}]
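# e.g. text_conversation("你好") -> [{'role': 'user', 'content': '你好'}]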
def question_answering(message: str, history: list):
    """Answer a text question with streaming output, then have the Aligner correct it."""
    conversation = text_conversation(SYSTEM_PROMPT, 'system')
    # Replay the conversation history
    for past_user_msg, past_bot_msg in history:
        if past_user_msg:
            conversation.extend(text_conversation(past_user_msg, 'user'))
        if past_bot_msg:
            conversation.extend(text_conversation(past_bot_msg, 'assistant'))
    # Append the current question
    current_question = message
    conversation.extend(text_conversation(current_question))
    # Call the base model API (streaming enabled)
    stream = base_client.chat.completions.create(
        model=base_model,
        stream=True,
        messages=conversation,
    )
    # Stream the base model's answer
    total_answer = ""
    base_section = "🌟 **原始回答:**\n"
    total_answer += base_section
    # NOTE: base_answer collects only the base model's text and later becomes the Aligner's input;
    # total_answer carries the full formatted output shown in the chat window
    base_answer = ""
    yield total_answer
    for chunk in stream:
        if chunk.choices[0].delta.content is not None:
            base_answer += chunk.choices[0].delta.content
            total_answer += chunk.choices[0].delta.content
            yield f"```bash\n{base_section}{base_answer}\n```"
    # Close the original-answer section and start the Aligner section
    aligner_section = "\n**Aligner 修正中...**\n\n🌟 **修正后回答:**\n"
    # Rebuild total_answer so that only the original answer stays inside the bash block
    total_answer = f"```bash\n{base_section}{base_answer}\n```{aligner_section}"
    yield total_answer
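    # The Aligner takes the original question plus the base model's answer, wrapped in the
    # ##Question / ##Answer / ##Correction template below, and streams back a corrected answer.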
    aligner_conversation = text_conversation(SYSTEM_PROMPT, 'system')
    aligner_current_question = f'##Question: {current_question}\n##Answer: {base_answer}\n##Correction: '
    aligner_conversation.extend(text_conversation(aligner_current_question))
    aligner_stream = aligner_client.chat.completions.create(
        model=aligner_model,
        stream=True,
        messages=aligner_conversation,
    )
aligner_answer = ""
for chunk in aligner_stream:
if chunk.choices[0].delta.content is not None:
aligner_answer += chunk.choices[0].delta.content
aligner_answer = aligner_answer.replace('##CORRECTION:', '')
yield f"```bash\n{base_section}{base_answer}\n```{aligner_section}{aligner_answer}"
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860, help="Gradio server port")
    parser.add_argument("--share", action="store_true", help="Create a public share link")
    args = parser.parse_args()
    # Build the Gradio chat interface (streaming enabled)
    iface = gr.ChatInterface(
        fn=question_answering,
        title='Aligner',
        description='网络安全 Aligner',
        examples=TEXT_EXAMPLES,
        theme=gr.themes.Soft(
            text_size='lg',
            spacing_size='lg',
            radius_size='lg',
        ),
    )
    iface.launch(server_port=args.port, share=args.share)
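# Example launch (assuming the base and Aligner servers above are already running):
#   python aligner_inference_demo.py --port 7860 --share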