inksiyu's picture
Upload 14 files
c871381 verified
import re
import sys
import requests
import logging
from typing import List, Tuple
from emb import generate_embeddings_from_qa_pairs
# Configure logging: INFO level, each record formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# OpenAI API configuration.
# API_URL and API_KEY start out empty; main() overwrites both globals from its
# arguments before any request is made.
API_URL = ""
API_KEY = ""
# Model name sent in the "model" field of every chat-completions request.
MODEL = "gpt-3.5-turbo"
# Q&A prompt template.
# - "description" is sent as the system prompt. It tells the model to study the
#   text inside the <Context></Context> tags and produce questions with answers
#   drawn only from that text (at most 40 questions, Markdown allowed).
# - "fixedText" is sent as the user prompt. It shows the expected
#   "Q1: ... / A1: ..." output layout and carries the placeholder that
#   replace_variable() fills with the source text.
# NOTE(review): the placeholder here is "{{{(text)}}}" while replace_variable()
# substitutes "{{(text)}}", so one literal "{" / "}" pair is left around the
# inserted text -- confirm whether that is intended.
Prompt_AgentQA = {
"description": "<Context></Context> 标记中是一段文本,学习和分析它,并整理学习成果:\n- 提出问题并给出每个问题的答案。\n- 你的拆分整理不得包含除了文本以外的知识内容,你的所有答案都只能从文本中获取\n- 答案需详细完整,尽可能保留原文描述。\n- 答案可以包含普通文字、链接、代码、表格、公示、媒体链接等 Markdown 元素。\n- 最多提出 40 个问题。\n",
"fixedText": "请按以下格式整理学习成果:\n<Context>\n文本\n</Context>\nQ1: 问题。\nA1: 答案。\nQ2:\nA2:\n\n------\n\n我们开始吧!\n\n<Context>\n{{{(text)}}}\n</Context>\n"
}
def replace_variable(text: str, obj: dict) -> str:
    """Fill ``{{(name)}}`` placeholders in ``text`` with values from ``obj``.

    Args:
        text: Template string containing ``{{(name)}}`` placeholders.
        obj: Mapping of placeholder name to replacement value. Only ``str``
            and ``int`` values are substituted; entries of any other type
            are skipped.

    Returns:
        The template with every matching placeholder replaced, or ``''`` when
        the result is empty/falsy.
    """
    for name in obj:
        value = obj[name]
        if not isinstance(value, (str, int)):
            continue
        # Literal placeholder: two braces, the name in parentheses, two braces.
        placeholder = "{{(" + format(name) + ")}}"
        text = text.replace(placeholder, str(value))
    return text if text else ''
def generate_qa(text: str) -> List[Tuple[str, str]]:
    """Ask the chat-completions endpoint to turn ``text`` into Q/A pairs.

    The system prompt is ``Prompt_AgentQA['description']`` and the user prompt
    is ``Prompt_AgentQA['fixedText']`` with ``text`` substituted in. The
    request is POSTed to the module-level ``API_URL`` with ``API_KEY`` as the
    bearer token, and the reply is parsed by ``format_split_text``.

    Args:
        text: Source text the questions and answers must be drawn from.

    Returns:
        The ``(question, answer)`` tuples produced by ``format_split_text``.
        An empty list is returned when the request fails or times out, or
        when the response body does not have the expected
        ``choices[0].message.content`` string.
    """
    system_prompt = Prompt_AgentQA['description']
    user_prompt = replace_variable(Prompt_AgentQA['fixedText'], {'text': text})
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    data = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.3,
        "stream": False,
    }
    try:
        # Without a timeout requests waits forever on a stalled server.
        # (connect, read) in seconds; the read budget is generous because the
        # non-streaming reply only arrives once generation has finished.
        response = requests.post(API_URL, headers=headers, json=data, timeout=(10, 300))
        response.raise_for_status()
        answer = response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        logging.error(f"OpenAI API request failed: {e}")
        return []
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # Non-JSON body, or JSON without the expected choices/message/content.
        logging.error(f"Unexpected OpenAI API response: {e!r}")
        return []
    if not isinstance(answer, str):
        # e.g. "content": null -- nothing to parse.
        logging.error("OpenAI API response contained no text content")
        return []
    return format_split_text(answer, text)
# One "Qn: question ... An: answer" pair. The question is the rest of the Qn
# line; the answer runs lazily up to the next "Qn:" marker or the end of text.
# The lookahead requires the colon so that a mention such as "Q3" inside an
# answer does not cut the answer short.
_QA_PAIR_RE = re.compile(r"Q\d+:\s*(.*)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)")

# Size, in characters, of the raw-text chunks used when no pair is found.
_FALLBACK_CHUNK_SIZE = 512


def format_split_text(text: str, raw_text: str) -> List[Tuple[str, str]]:
    """Parse a "Q1:/A1:/Q2:/A2:..." model reply into (question, answer) pairs.

    Args:
        text: The model's reply. Escaped newlines written out literally
            (backslash followed by ``n``) are converted to real newlines
            before parsing.
        raw_text: The original source text, used only for the fallback.

    Returns:
        A list of ``(question, answer)`` tuples with surrounding whitespace
        stripped; pairs whose question is empty are dropped. If no pair is
        found, ``raw_text`` is cut into 512-character chunks and returned as
        ``(chunk, '')`` tuples (an empty list when ``raw_text`` is empty).
    """
    text = text or ""
    # Handle the doubly-escaped form first, then the plain literal "\n".
    text = text.replace("\\\\n", "\n").replace("\\n", "\n")

    result = []
    for question, answer in _QA_PAIR_RE.findall(text):
        question = question.strip()
        if question:
            result.append((question, answer.strip()))

    # Nothing parseable: fall back to splitting the original text into chunks.
    if not result:
        result = [
            (raw_text[i:i + _FALLBACK_CHUNK_SIZE], '')
            for i in range(0, len(raw_text), _FALLBACK_CHUNK_SIZE)
        ]
    return result
def main(text, api_key, api_url_base):
    """Generate Q/A pairs for ``text`` and hand them to the embedding step.

    Sets the module-level ``API_KEY`` and ``API_URL`` used by ``generate_qa``,
    runs the Q/A generation, then passes the pairs to
    ``generate_embeddings_from_qa_pairs`` together with the same credentials.

    Args:
        text: Source text to turn into question/answer pairs.
        api_key: Bearer token for the API.
        api_url_base: Base URL of the API; ``/chat/completions`` is appended
            to it. A trailing slash is tolerated.

    Returns:
        The list of ``(question, answer)`` tuples from ``generate_qa``.
    """
    global API_KEY, API_URL
    API_KEY = api_key
    # Strip trailing slashes so "https://host/v1/" does not become
    # "https://host/v1//chat/completions".
    API_URL = f"{str(api_url_base).rstrip('/')}/chat/completions"
    qa_pairs = generate_qa(text)
    # The base URL is forwarded unchanged; the embedding helper builds its own URL.
    generate_embeddings_from_qa_pairs(qa_pairs, api_key, api_url_base)
    return qa_pairs
if __name__ == "__main__":
    # Command-line entry point: <script> <text> <api_key> <api_url_base>
    # The text argument is the search result to be turned into Q/A pairs.
    if len(sys.argv) < 4:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit(f"Usage: {sys.argv[0]} <text> <api_key> <api_url_base>")
    text, api_key, api_url_base = sys.argv[1:4]
    main(text, api_key, api_url_base)