File size: 6,548 Bytes
90126b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import copy
import json
import os
import queue
import time
from concurrent.futures import ThreadPoolExecutor

from openai import OpenAI
from tqdm import tqdm

# System prompt (Markdown, Chinese) sent with every request and stored in each record's "system" field.
# Must be a plain str: a trailing comma here would silently turn it into a 1-tuple.
base_id_prompt = "# Role: 问答机器人\n\n## Profile\n- author: 尖米\n- version: 1.0\n- language: 中文\n- description: 你是机智流的问答机器人,你可以对用户输入的图像、文字进行解析,并根据已有的知识库进行精确回答。\n\n## Skills\n1. 图像识别与解析:能够识别用户上传的图像,并提取其中的关键信息。\n2. 自然语言处理:能够理解并解析用户输入的文字信息,准确把握用户意图。\n3. 知识库应用:根据解析结果,查询知识库,提供准确、相关的答案。\n4. 多轮对话:支持与用户进行多轮对话,提供连续性、上下文相关的回答。\n\n## Rules\n1. 必须充分理解用户输入的图像和文字内容。\n2. 回答需要简洁明了,避免过于复杂或含糊的表述。\n3. 在回答过程中,优先查询和引用公司已有的知识库。\n4. 对于无法回答的问题,需要引导用户提供更多信息或寻求人工客服帮助。\n\n## Workflows\n1. 接收并分析用户输入的图像或文字信息。\n2. 基于图像识别或自然语言处理技术,提取关键信息。\n3. 查询知识库,匹配相关信息。\n4. 向用户提供精准、相关的回答。\n5. 如有必要,进行多轮对话,确保问题得到有效解决。\n\n## Init\n欢迎使用机智流的问答机器人,请输入您的问题,我将尽力为您提供帮助。\n"

# OpenAI-compatible chat clients, keyed by the name passed as ``client_name``.
# Keys and endpoints are taken from the environment when set, so real
# credentials need not be written into this file; otherwise the original
# placeholder values are used unchanged.
clients = {
    "internlm": OpenAI(
        api_key=os.environ.get("INTERNLM_API_KEY", "your_internlm_api_key"),
        base_url=os.environ.get(
            "INTERNLM_BASE_URL",
            "https://internlm-chat.intern-ai.org.cn/puyu/api/v1/",
        ),
    ),
    "glm": OpenAI(
        api_key=os.environ.get("GLM_API_KEY", "your_glm_api_key"),
        base_url=os.environ.get("GLM_BASE_URL", "your_glm_url"),
    ),
    "deepseek": OpenAI(
        api_key=os.environ.get("DEEPSEEK_API_KEY", "your_deepseek_api_key"),
        base_url=os.environ.get("DEEPSEEK_BASE_URL", "your_deepseek_url"),
    ),
}

class BaseDataAPI:
    """Shared plumbing for building conversation-format training data from a chat API.

    Args:
        questions_path: UTF-8 text file read by ``read_questions`` (one question per line).
        save_path: JSONL file that ``save`` appends records to.
        repeat: when non-zero, the question list is repeated this many times.
        client_name: key into the module-level ``clients`` mapping.
        model: model name passed to ``chat.completions.create``; override it when
            ``client_name`` selects a provider that does not serve the default model.
    """

    def __init__(self, questions_path, save_path, repeat=0, client_name="internlm",
                 model="internlm2.5-latest"):
        self.client = clients[client_name]
        self.questions_path = questions_path
        self.save_path = save_path
        self.repeat = repeat
        self.model = model
        # Skeleton of one training record; build_data deep-copies it and fills
        # in "input"/"output".
        self.data_template = {
            "conversation": [
                {
                    "system": base_id_prompt,
                    "input": "xxx",
                    "output": "xxx",
                }
            ]
        }

    def get_answer(self, question):
        """Send ``question`` under the system prompt and return it as a training record."""
        chat_rsp = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": base_id_prompt},
                {"role": "user", "content": question},
            ],
            stream=False,
        )
        return self.build_data(question, chat_rsp)

    def build_data(self, question, chat_rsp):
        """Return a fresh copy of the template holding ``question`` and the first reply choice."""
        record = copy.deepcopy(self.data_template)  # never mutate the shared template
        turn = record["conversation"][0]
        turn["input"] = question
        turn["output"] = chat_rsp.choices[0].message.content
        return record

    def save(self, train_data):
        """Append each record in ``train_data`` to ``save_path`` as one JSON line."""
        with open(self.save_path, "a", encoding="utf-8") as f:
            for item in train_data:
                json.dump(item, f, ensure_ascii=False)
                f.write("\n")

    @staticmethod
    def load_txt(path):
        """Return the whole content of ``path`` decoded as UTF-8."""
        with open(path, "r", encoding="utf-8") as f:
            return f.read()

    def read_questions(self):
        """Return the non-blank lines of ``questions_path``, repeated ``repeat`` times if non-zero."""
        lines = self.load_txt(self.questions_path).split("\n")
        # Drop blank lines (including the one produced by a trailing newline);
        # otherwise they would be sent to the API as empty prompts.
        promptlist = [line for line in lines if line.strip()]
        if self.repeat != 0:
            promptlist = promptlist * self.repeat
        print(f"Total questions: {len(promptlist)}")
        return promptlist

class GetDataApi(BaseDataAPI):
    """Generate single-turn training records by asking the model every question.

    Questions come from ``read_questions()``; each is answered through
    ``get_answer()`` and the resulting records are appended to ``save_path``
    in batches.
    """

    # Number of records buffered before they are appended to save_path.
    save_every = 10
    # Maximum number of concurrent API requests.
    max_workers = 10

    def run(self):
        """Ask every question concurrently and append the answers to ``save_path``.

        Records are written in question order, ``save_every`` at a time, so
        finished work reaches disk even if the run is interrupted. A request
        that raises is reported and skipped rather than aborting the run, which
        would otherwise discard the unsaved buffer and every in-flight answer.
        """
        promptlist = self.read_questions()
        buffer = []
        failures = 0
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            print("Asking...")
            futures = [pool.submit(self.get_answer, question) for question in promptlist]
            for question, future in tqdm(zip(promptlist, futures), total=len(futures)):
                try:
                    record = future.result()
                except Exception as err:  # whatever the API client raised for this request
                    failures += 1
                    tqdm.write(f"Request failed for {question!r}: {err!r}")
                    continue
                buffer.append(record)
                if len(buffer) >= self.save_every:
                    self.save(buffer)
                    buffer = []

        # Flush the final partial batch.
        if buffer:
            self.save(buffer)
        if failures:
            print(f"{failures} request(s) failed and were skipped")

    def __call__(self):
        """Alias for ``run()`` so instances can be invoked directly, like ``ChatData``."""
        return self.run()

class ChatData(BaseDataAPI):
    """Ask the model again about previously generated answers.

    ``train_data`` is a JSONL file whose records use the ``conversation``
    layout built by ``build_data``. For every record, the first turn's
    ``output`` is sent as the new prompt, and the new reply is saved paired
    with that turn's original ``input``.
    """

    # Number of records buffered before they are appended to save_path.
    save_every = 10
    # Maximum number of concurrent API requests.
    max_workers = 10

    def __init__(self, train_data, save_path, client_name="internlm", model="internlm2.5-latest"):
        super().__init__(train_data, save_path, client_name=client_name)
        self.train_data = train_data
        # Model used by ask_for_tts; set here so it can be chosen per instance.
        self.model = model

    def load_data(self):
        """Return the raw lines of the ``train_data`` file (UTF-8)."""
        with open(self.train_data, 'r', encoding='utf-8') as f:
            return f.readlines()

    def ask_for_tts(self, question, save_ask):
        """Send ``question`` to the model and return a record whose input is ``save_ask``."""
        chat_rsp = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": base_id_prompt},
                {"role": "user", "content": question},
            ],
            stream=False,
        )
        return self.build_data(save_ask, chat_rsp)

    def __call__(self):
        """Re-ask every record concurrently and append the new records to ``save_path``.

        Blank lines in the input file (e.g. a trailing newline) are skipped. All
        lines are parsed before any request is sent, so a malformed line fails
        fast instead of after API calls have already been made. A request that
        raises is reported and skipped rather than aborting the run.
        """
        records = [json.loads(line) for line in self.load_data() if line.strip()]
        buffer = []
        failures = 0
        with ThreadPoolExecutor(max_workers=self.max_workers) as pool:
            print("Asking...")
            futures = []
            for record in records:
                turn = record['conversation'][0]
                # The previous answer becomes the new prompt; the original question is kept as input.
                futures.append(pool.submit(self.ask_for_tts, turn['output'], turn['input']))

            for future in tqdm(futures):
                try:
                    result = future.result()
                except Exception as err:  # whatever the API client raised for this request
                    failures += 1
                    tqdm.write(f"Request failed: {err!r}")
                    continue
                buffer.append(result)
                if len(buffer) >= self.save_every:
                    self.save(buffer)
                    buffer = []

        # Flush the final partial batch.
        if buffer:
            self.save(buffer)
        if failures:
            print(f"{failures} request(s) failed and were skipped")

if __name__ == '__main__':
    questions_path = './tools/L1_XTuner_code/Q_list.txt'
    save_path = './data/train_basic.jsonl'
    # save() opens the output in append mode, which fails if the directory is missing.
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
    start_time = time.perf_counter()
    chat_data = GetDataApi(questions_path, save_path)
    # run() is GetDataApi's entry point.
    chat_data.run()
    end_time = time.perf_counter()
    print('Done')
    print(f'Time used: {end_time - start_time:.2f} seconds')