"""Turn a block of text into question/answer pairs using an OpenAI-compatible
chat-completions endpoint, then pass the pairs to
``emb.generate_embeddings_from_qa_pairs``.

Command-line usage::

    python <this script> TEXT API_KEY API_URL_BASE
"""

import logging
import re
import sys
from typing import List, Tuple

import requests

from emb import generate_embeddings_from_qa_pairs

# Configure logging.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# OpenAI API configuration. API_URL and API_KEY are filled in by main().
API_URL = ""
API_KEY = ""
MODEL = "gpt-3.5-turbo"

# Seconds to wait for the chat-completions endpoint. Without a timeout,
# requests.post can block indefinitely on a stalled connection.
REQUEST_TIMEOUT = 120

# Chunk length (in characters) used to split the raw text when the model's
# reply contains no parseable Q/A pairs.
FALLBACK_CHUNK_SIZE = 512

# Q&A prompt template.
# NOTE(review): "description" refers to text inside a "标记" (marker), but no
# such marker appears in "fixedText"; it looks like surrounding tags were lost
# at some point. Also, replace_variable() substitutes the pattern "{{(text)}}",
# so the "{{{(text)}}}" placeholder below leaves the inserted text wrapped in a
# single pair of braces. Confirm both are intended before editing the prompt.
Prompt_AgentQA = {
    "description": " 标记中是一段文本,学习和分析它,并整理学习成果:\n- 提出问题并给出每个问题的答案。\n- 你的拆分整理不得包含除了文本以外的知识内容,你的所有答案都只能从文本中获取\n- 答案需详细完整,尽可能保留原文描述。\n- 答案可以包含普通文字、链接、代码、表格、公示、媒体链接等 Markdown 元素。\n- 最多提出 40 个问题。\n",
    "fixedText": "请按以下格式整理学习成果:\n\n文本\n\nQ1: 问题。\nA1: 答案。\nQ2:\nA2:\n\n------\n\n我们开始吧!\n\n\n{{{(text)}}}\n\n"
}

# Matches "Qn: question" followed by "An: answer". The answer runs (lazily) up
# to the next "Qn:" marker or the end of the text. Both the ASCII ":" and the
# full-width "：" are accepted, since Chinese-language replies often use the
# latter. Groups 1/3/4 capture whitespace, group 2 is the question line, and
# group 5 is the answer.
_QA_PATTERN = re.compile(r"Q\d+[:：](\s*)(.*)(\s*)A\d+[:：](\s*)([\s\S]*?)(?=Q\d+[:：]|$)")


def replace_variable(text: str, obj: dict) -> str:
    """Substitute ``{{(key)}}`` placeholders in *text* with values from *obj*.

    Only ``str`` and ``int`` values (``bool`` included, as a subclass of
    ``int``) are substituted; values of other types are skipped.

    Returns:
        The substituted text, or ``''`` if the result is empty/falsy.
    """
    for key, val in obj.items():
        if isinstance(val, (str, int)):
            # f"{{{{({key})}}}}" renders to the literal "{{(key)}}".
            text = text.replace(f"{{{{({key})}}}}", str(val))
    return text or ''


def generate_qa(text: str) -> List[Tuple[str, str]]:
    """Ask the model to turn *text* into question/answer pairs.

    Reads the module-level API_URL, API_KEY and MODEL (set by main()).

    Returns:
        A list of ``(question, answer)`` tuples as produced by
        format_split_text(). Returns an empty list if the HTTP request fails
        or the response does not have the expected
        ``choices[0].message.content`` string.
    """
    system_prompt = Prompt_AgentQA['description']
    user_prompt = replace_variable(Prompt_AgentQA['fixedText'], {'text': text})

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }
    data = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        "temperature": 0.3,
        "stream": False
    }

    try:
        response = requests.post(API_URL, headers=headers, json=data, timeout=REQUEST_TIMEOUT)
        response.raise_for_status()
        payload = response.json()
    except requests.exceptions.RequestException as e:
        logging.error(f"OpenAI API request failed: {e}")
        return []
    except ValueError as e:
        # Body was not valid JSON. Older requests versions raise a plain
        # ValueError here; newer ones raise a RequestException subclass,
        # which the branch above already handles.
        logging.error(f"OpenAI API returned invalid JSON: {e}")
        return []

    try:
        answer = payload["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        logging.error(f"Unexpected OpenAI API response shape: {e!r}")
        return []
    if not isinstance(answer, str):
        # e.g. content is null; format_split_text() needs a string.
        logging.error("OpenAI API response contains no text content")
        return []

    return format_split_text(answer, text)


def format_split_text(text: str, raw_text: str,
                      chunk_size: int = FALLBACK_CHUNK_SIZE) -> List[Tuple[str, str]]:
    """Parse ``Qn: ... / An: ...`` pairs out of the model reply *text*.

    Args:
        text: The model's reply.
        raw_text: The original source text. It is used only as a fallback.
        chunk_size: Length of each fallback chunk, in characters.

    Returns:
        A list of ``(question, answer)`` tuples with surrounding whitespace
        stripped. Pairs whose question is empty are dropped. If no pair is
        found, *raw_text* is split into consecutive *chunk_size*-character
        chunks, each returned as ``(chunk, '')``.

    Raises:
        ValueError: If *chunk_size* is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Turn the two-character escape sequence "\n" (backslash + n), which some
    # replies contain, into a real newline.
    text = text.replace("\\n", "\n")

    result = []
    for match in _QA_PATTERN.finditer(text):
        q = match.group(2).strip()
        a = match.group(5).strip()
        if q:
            result.append((q, a))

    # No Q/A pairs found: fall back to splitting the original text into chunks.
    if not result:
        result = [(raw_text[i:i + chunk_size], '') for i in range(0, len(raw_text), chunk_size)]
    return result


def main(text, api_key, api_url_base):
    """Configure the API, generate Q/A pairs for *text*, and embed them.

    Sets the module-level API_KEY and API_URL, then calls generate_qa(). The
    resulting pairs are passed to generate_embeddings_from_qa_pairs() together
    with *api_key* and *api_url_base*.

    Returns:
        The list of ``(question, answer)`` tuples.
    """
    global API_KEY, API_URL
    API_KEY = api_key
    # Strip a trailing slash so that "https://host/v1/" does not produce
    # ".../v1//chat/completions".
    API_URL = f"{str(api_url_base).rstrip('/')}/chat/completions"
    qa_pairs = generate_qa(text)
    generate_embeddings_from_qa_pairs(qa_pairs, api_key, api_url_base)
    return qa_pairs


if __name__ == "__main__":
    if len(sys.argv) < 4:
        sys.exit(f"Usage: {sys.argv[0]} TEXT API_KEY API_URL_BASE")
    text = sys.argv[1]  # search results passed in on the command line
    api_key = sys.argv[2]
    api_url_base = sys.argv[3]
    main(text, api_key, api_url_base)