youngtsai's picture
refactor
b7db810
raw
history blame contribute delete
No virus
7.79 kB
import gradio as gr
import json
import os
from openai import OpenAI
import re
from azure.cognitiveservices.speech import SpeechConfig, SpeechSynthesizer, AudioConfig
# Required configuration, read from the environment when the module is loaded.
# Indexing os.environ directly means a missing variable raises KeyError here,
# before any UI is built.
# NOTE(review): in the code visible here main_function re-reads PASSWORD from
# the environment instead of using this constant, and the AZURE_* values are
# not referenced — confirm whether other parts of the project rely on them.
PASSWORD = os.environ['PASSWORD']
OPEN_AI_KEY = os.environ['OPEN_AI_KEY']  # passed to the OpenAI client as api_key
AZURE_REGION = os.environ['AZURE_REGION']
AZURE_API_KEY = os.environ['AZURE_API_KEY']
def validate_and_correct_chat(data, roles=("A", "B"), rounds=2):
    """
    Corrects the chat data to ensure proper roles and number of rounds.

    Entries whose role is not one of ``roles`` are rewritten in place: the
    entry is given the role that follows the previous entry's role in
    ``roles`` (wrapping around). An invalid first entry has no predecessor,
    so it is given the first role.

    Parameters:
    - data (list): The chat data list of dicts, e.g. [{"role": "A", "content": "Hi"}, ...]
    - roles (sequence): The expected roles in speaking order, default is ("A", "B")
    - rounds (int): The number of rounds expected; one round is one entry per role.
      Integral floats such as 2.0 are accepted and truncated with int().
    Returns:
    - list: Corrected chat data. The dicts are the same objects that were
      passed in; the list itself is a truncated copy only when it was too long.
    """
    roles = list(roles)

    # Validate role names
    for position, item in enumerate(data):
        if item['role'] in roles:
            continue
        print(f"Invalid role '{item['role']}' detected. Correcting it.")
        if position == 0:
            # Nothing precedes the first entry, so start the sequence.
            item['role'] = roles[0]
        else:
            # The previous entry is valid by now (either originally or
            # because it was corrected on an earlier iteration).
            prev_index = roles.index(data[position - 1]['role'])
            item['role'] = roles[(prev_index + 1) % len(roles)]

    # Validate number of rounds
    expected_entries = int(rounds) * len(roles)
    if len(data) > expected_entries:
        print(f"Too many rounds detected. Trimming the chat to {rounds} rounds.")
        data = data[:expected_entries]
    return data
def extract_json_from_response(response_text):
    """
    Extract the first JSON array of objects embedded in a model response.

    The text is scanned for every "[" followed by "{" and a real JSON
    decoder is started at that position. Unlike a lazy regular expression,
    the decoder is not fooled by "}]" appearing inside a string value and it
    ignores any text that follows the array.

    Parameters:
    - response_text (str): Raw text returned by the model.
    Returns:
    - list: The decoded JSON array.
    Raises:
    - json.JSONDecodeError: an array opening was found but none could be
      decoded (the error of the last attempt is raised).
    - ValueError: the text contains no array opening at all.
    """
    decoder = json.JSONDecoder()
    last_error = None
    for opening in re.finditer(r'\[\s*\{', response_text):
        try:
            parsed, _end = decoder.raw_decode(response_text, opening.start())
        except json.JSONDecodeError as err:
            # Keep looking: a later "[{" may start a well-formed array.
            last_error = err
            continue
        return parsed
    if last_error is not None:
        raise last_error
    raise ValueError("JSON dialogue not found in the response.")
def create_chat_dialogue(rounds, role1, role1_gender, role2, role2_gender, theme, language, cefr_level):
    """
    Ask the OpenAI chat completions API for a dialogue between two roles.

    A system message and a user prompt are built from the arguments, sent to
    the "gpt-4-1106-preview" model, and the JSON dialogue found in the reply
    is extracted and passed through validate_and_correct_chat. The request
    messages, the raw response and the final dialogue are printed.

    Parameters:
    - rounds: Number of rounds; converted with int() for the prompt and the token budget.
    - role1, role2: Names of the two participants.
    - role1_gender, role2_gender: Gender text inserted into the prompt.
    - theme: Topic of the conversation.
    - language: Language the dialogue should be generated in.
    - cefr_level: CEFR level inserted into the prompt.
    Returns:
    - The value returned by validate_and_correct_chat for the extracted dialogue.
    """
    client = OpenAI(api_key=OPEN_AI_KEY)

    # One round is one sentence per role, so the sentence total is twice the rounds.
    sentenses_count = int(rounds) * 2
    # Token budget grows with the number of rounds; adjust the factor if needed.
    token_budget = 500 * int(rounds)

    sys_content = f"你是一個{language}家教,請用{language}生成對話"
    prompt = f"您將進行一場以{theme}為主題的對話,請用 cefr_level:{cefr_level} 為對話的程度。{role1} (gender: {role1_gender}) 和{role2} (gender: {role2_gender})將是參與者。請依次交談{rounds}輪。(1輪對話的定義是 {role1}{role2} 各說一句話,總共 {sentenses_count} 句話。)以json格式儲存對話。並回傳對話JSON文件。格式為:[{{role:\"{role1}\", \"gender\": {role1_gender} , content: \".....\"}}, {{role:\"{role2}\", \"gender\": {role2_gender}, content: \".....\"}}]"
    messages = [
        {"role": "system", "content": sys_content},
        {"role": "user", "content": prompt},
    ]

    banner = "=====messages====="
    print(banner)
    print(messages)
    print(banner)

    response = client.chat.completions.create(
        model="gpt-4-1106-preview",
        messages=messages,
        max_tokens=token_budget,
    )
    print(response)

    reply_text = response.choices[0].message.content.strip()
    dialogue = validate_and_correct_chat(
        data=extract_json_from_response(reply_text),
        roles=[role1, role2],
        rounds=rounds,
    )
    print(dialogue)
    return dialogue
def generate_dialogue(rounds, method, role1, role1_gender, role2, role2_gender, theme, language, cefr_level):
    """
    Produce a dialogue either from the model or as a fixed manual placeholder.

    When ``method`` is "auto" the work is delegated to create_chat_dialogue
    with the same arguments. Any other value yields a two-entry placeholder
    dialogue, one entry per role, carrying that role's name and gender.

    Returns:
    - list: Dialogue entries as dicts with "role", "gender" and "content" keys
      (for "auto", whatever create_chat_dialogue returns).
    """
    if method != "auto":
        return [
            {"role": role1, "gender": role1_gender, "content": "手動輸入文本 1"},
            {"role": role2, "gender": role2_gender, "content": "手動輸入文本 2"},
        ]
    return create_chat_dialogue(rounds, role1, role1_gender, role2, role2_gender, theme, language, cefr_level)
def main_function(password: str, theme: str, language: str, cefr_level: str, method: str, rounds: int, role1: str, role1_gender: str, role2: str, role2_gender: str):
    """
    Handle a submit from the UI: check the password, build the dialogue and
    return the three values expected by the output components.

    Parameters:
    - password: Text entered by the user; compared with the PASSWORD environment variable.
    - theme, language, cefr_level, method, rounds, role1, role1_gender,
      role2, role2_gender: Forwarded unchanged to generate_dialogue.
    Returns:
    - tuple of three items, in the order the click handler lists its outputs:
      1. list of (first_content, second_content) pairs for the chat display,
      2. path of the written dialogue file, or None when the password is wrong,
      3. the dialogue as a JSON string, or the error message when the password is wrong.

    Side effects: on success the JSON text is written to "dialogue_output.txt"
    in the current working directory, overwriting any previous file.
    """
    if password != os.environ.get("PASSWORD", ""):
        # Return exactly three values: an empty chat, no file, and the
        # message in the text slot. A two-item return does not match the
        # three outputs wired to this handler.
        return [], None, "错误的密码,请重新输入。"

    structured_dialogue = generate_dialogue(rounds, method, role1, role1_gender, role2, role2_gender, theme, language, cefr_level)

    # Pair consecutive entries so the chat display shows them side by side;
    # an odd trailing entry is paired with an empty string.
    total = len(structured_dialogue)
    chatbot_dialogue = [
        (
            f"{structured_dialogue[i]['content']}",
            f"{structured_dialogue[i + 1]['content']}" if i + 1 < total else "",
        )
        for i in range(0, total, 2)
    ]

    json_output = json.dumps({"dialogue": structured_dialogue}, ensure_ascii=False, indent=4)

    # Save the dialogue so it can be offered as a download.
    file_name = "dialogue_output.txt"
    with open(file_name, "w", encoding="utf-8") as f:
        f.write(json_output)
    return chatbot_dialogue, file_name, json_output
if __name__ == "__main__":
    # Build the UI with the Soft theme: outputs in a wide left column,
    # inputs and buttons in a narrow right column.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Row():
            with gr.Column(scale=2):  # two thirds of the width
                chat_output = gr.Chatbot(label="生成的對話")
                json_file = gr.File(label="下載對話 JSON 文件")
                # Display-only box: Textbox has no `readonly` option, the
                # supported way to make it non-editable is interactive=False.
                json_textbox = gr.Textbox(interactive=False, label="對話 JSON 內容", lines=10)
            with gr.Column(scale=1):  # one third of the width
                password = gr.Textbox(label="输入密码", type="password")
                theme = gr.Textbox(label="對話主題")  # conversation topic; no default value is set
                language = gr.Dropdown(choices=["中文", "英文"], label="語言")
                cefr_level = gr.Dropdown(choices=["A1", "A2", "B1", "B2", "C1", "C2"], label="CEFR Level")
                generation_mode = gr.Dropdown(choices=["auto", "manual"], label="生成方式")
                rounds = gr.Slider(minimum=2, maximum=6, step=2, label="對話輪數")
                role1_name = gr.Textbox(label="角色 1 名稱")
                role1_gender = gr.Dropdown(choices=["male", "female"], label="角色 1 性別")
                role2_name = gr.Textbox(label="角色 2 名稱")
                role2_gender = gr.Dropdown(choices=["male", "female"], label="角色 2 性別")
                # Submit and clear buttons
                submit_button = gr.Button("Submit")
                clear_button = gr.Button("Clear")

        # Inputs are listed in the same order as main_function's parameters;
        # outputs in the same order as its three return values.
        submit_button.click(
            main_function,
            [
                password,
                theme,
                language,
                cefr_level,
                generation_mode,
                rounds,
                role1_name,
                role1_gender,
                role2_name,
                role2_gender,
            ],
            [
                chat_output,
                json_file,
                json_textbox,
            ],
        )
        # Reset the three outputs to their empty states.
        clear_button.click(lambda: [[], None, ""], None, [chat_output, json_file, json_textbox], queue=False)

    demo.launch(inline=False, share=True)