File size: 2,126 Bytes
88aba71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import json
import openai
from openai import OpenAI  # 导入 OpenAI 类

from tqdm import tqdm
from typing import List, Dict, cast # 导入 cast
from openai.types.chat import ChatCompletionMessageParam # 导入消息参数类型

from weclone.utils.config import load_config

# Attribute-style settings object for this test script. Only the system prompt
# is taken from the "web_demo" section of the project config; the model name
# and history length are fixed here. (history_len is not read anywhere in this
# script.)
config = type(
    "Config",
    (object,),
    {
        "default_prompt": load_config("web_demo")["default_system"],
        "model": "gpt-3.5-turbo",
        "history_len": 15,
    },
)()

# Initialize the OpenAI client against a locally served, OpenAI-compatible API
# on port 8005. The key is presumably a placeholder the local server accepts —
# confirm before pointing this at a real endpoint.
client = OpenAI(
    base_url="http://127.0.0.1:8005/v1",
    api_key="sk-test",
)


def handler_text(content: str, history: list, config, max_tokens: int = 50):
    """Send one user turn to the chat model and record the exchange in *history*.

    The request consists of a system message built from ``config.default_prompt``,
    every message already in *history*, and the new user message.

    Args:
        content: Text of the new user message.
        history: List of ``{"role": ..., "content": ...}`` dicts for the
            conversation so far. Mutated in place: on success the user message
            and the assistant reply are appended; if the request fails it is
            left unchanged.
        config: Object exposing ``default_prompt`` and ``model`` attributes.
        max_tokens: Maximum number of tokens requested for the reply
            (defaults to 50, the previously hard-coded value).

    Returns:
        The assistant reply with every newline-followed-by-space sequence
        removed, or, when the API raises ``openai.APIError``, an error string
        made of the fixed Chinese prefix "AI接口出错,请重试" ("AI API error,
        please retry"), a newline, and the exception text.
    """
    user_message = {"role": "user", "content": content}
    messages = [
        {"role": "system", "content": f"{config.default_prompt}"},
        *history,
        user_message,
    ]
    try:
        response = client.chat.completions.create(
            model=config.model,
            # cast() only informs the type checker; the list is passed unchanged.
            messages=cast(List[ChatCompletionMessageParam], messages),
            max_tokens=max_tokens,
        )
    except openai.APIError as e:
        # History has not been modified yet, so there is nothing to roll back.
        return "AI接口出错,请重试\n" + str(e)

    # message.content is optional in the SDK; treat a missing reply as an empty
    # string instead of recording the literal text "None".
    reply = response.choices[0].message.content or ""
    reply = reply.replace("\n ", "")
    # Commit the user turn and the reply together, only after success, so the
    # history never ends with an unanswered user message.
    history.append(user_message)
    history.append({"role": "assistant", "content": reply})
    return reply


def main():
    """Replay every test conversation against the model and write the transcripts.

    Reads ``dataset/test_data.json`` and iterates its ``"questions"`` entry,
    which is presumably a list of conversations, each a list of user question
    strings. Every conversation is played in order with a fresh history. Every
    recorded message (user questions and assistant replies) is then written to
    ``test_result-my.txt``, one per line, with a blank line after each
    conversation.
    """
    with open("dataset/test_data.json", "r", encoding="utf-8") as f:
        conversations = json.load(f)["questions"]

    transcripts = []
    for questions in tqdm(conversations, desc=" Testing..."):
        history = []
        for q in questions:
            # NOTE(review): on openai.APIError handler_text returns an error
            # string and leaves history unchanged, so the failed turn is
            # silently omitted from the written transcript.
            handler_text(q, history=history, config=config)
        transcripts.append(history)

    # Explicit UTF-8 so non-ASCII (e.g. Chinese) transcripts are written
    # correctly regardless of the platform's default encoding; `with` ensures
    # the file is flushed and closed.
    with open("test_result-my.txt", "w", encoding="utf-8") as res_file:
        for history in transcripts:
            for message in history:
                res_file.write(message["content"] + "\n")
            res_file.write("\n")


if __name__ == "__main__":
    main()