# Spaces: Runtime error
import re | |
import sys | |
import requests | |
import logging | |
from typing import List, Tuple | |
from emb import generate_embeddings_from_qa_pairs | |
# Configure logging: INFO level, with timestamp, level name and message on each record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# OpenAI API configuration.
# API_URL and API_KEY start out empty; main() assigns them at runtime from its arguments.
API_URL = ""
API_KEY = ""
# Model name sent in the "model" field of every chat-completions request.
MODEL = "gpt-3.5-turbo"
# Question/answer prompt template.
#   "description": used by generate_qa() as the system message.
#   "fixedText":   used by generate_qa() as the user message after replace_variable()
#                  has substituted the "text" placeholder.
# Both values are runtime strings sent to the model and must not be edited casually.
Prompt_AgentQA = {
    "description": "<Context></Context> 标记中是一段文本,学习和分析它,并整理学习成果:\n- 提出问题并给出每个问题的答案。\n- 你的拆分整理不得包含除了文本以外的知识内容,你的所有答案都只能从文本中获取\n- 答案需详细完整,尽可能保留原文描述。\n- 答案可以包含普通文字、链接、代码、表格、公示、媒体链接等 Markdown 元素。\n- 最多提出 40 个问题。\n",
    "fixedText": "请按以下格式整理学习成果:\n<Context>\n文本\n</Context>\nQ1: 问题。\nA1: 答案。\nQ2:\nA2:\n\n------\n\n我们开始吧!\n\n<Context>\n{{{(text)}}}\n</Context>\n"
}
def replace_variable(text: str, obj: dict) -> str:
    """Substitute ``{{(key)}}`` placeholders in *text* with values from *obj*.

    Only values that are ``str`` or ``int`` instances are substituted; entries
    with any other value type are left untouched in the text.

    Args:
        text: Template text containing zero or more ``{{(key)}}`` placeholders.
        obj: Mapping of placeholder names to replacement values.

    Returns:
        The text with all eligible placeholders replaced, or ``''`` when the
        resulting text is empty/falsy.
    """
    for name in obj:
        value = obj[name]
        if not isinstance(value, (str, int)):
            continue
        # The literal placeholder is two opening braces, the key wrapped in
        # parentheses, then two closing braces.
        placeholder = "{{(" + str(name) + ")}}"
        text = text.replace(placeholder, str(value))
    return text if text else ''
def generate_qa(text: str, timeout: float = 120.0) -> List[Tuple[str, str]]:
    """Ask the chat-completions endpoint to turn *text* into question/answer pairs.

    The system message is ``Prompt_AgentQA['description']`` and the user
    message is ``Prompt_AgentQA['fixedText']`` with its ``text`` placeholder
    filled in. The request is sent to the module-level ``API_URL`` using
    ``API_KEY`` as a bearer token and ``MODEL`` as the model name.

    Args:
        text: Source text the model should derive questions and answers from.
        timeout: Seconds to wait for the HTTP request before giving up.
            Without a timeout a stalled server would block the caller forever.

    Returns:
        The ``(question, answer)`` tuples produced by ``format_split_text``
        from the model's reply, or an empty list when the request fails or
        the response does not have the expected shape.
    """
    system_prompt = Prompt_AgentQA['description']
    # NOTE(review): the template placeholder is written ``{{{(text)}}}`` while
    # replace_variable() substitutes ``{{(text)}}``, so the inserted text ends
    # up wrapped in one literal ``{`` / ``}`` pair. Confirm whether that is intended.
    user_prompt = replace_variable(Prompt_AgentQA['fixedText'], {'text': text})

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    data = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.3,
        "stream": False,
    }

    try:
        response = requests.post(API_URL, headers=headers, json=data, timeout=timeout)
        response.raise_for_status()
        answer = response.json()["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        # Covers connection errors, timeouts, HTTP error statuses and invalid URLs.
        logging.error("OpenAI API request failed: %s", e)
        return []
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # A 2xx response whose body is not JSON or lacks choices[0].message.content.
        logging.error("OpenAI API returned an unexpected response: %r", e)
        return []

    if not isinstance(answer, str):
        # e.g. "content": null -- there is nothing to parse into Q/A pairs.
        logging.error("OpenAI API returned no text content")
        return []

    return format_split_text(answer, text)
# Matches one "Qn: question ... An: answer" pair. The question is the rest of
# the Q line; the answer runs lazily until the next "Qn:" marker or the end of
# the text. Requiring the colon in the lookahead keeps answers that merely
# mention something like "Q3 2023" from being cut short.
_QA_PAIR_PATTERN = re.compile(
    r"Q\d+:\s*(?P<question>.*)\s*A\d+:\s*(?P<answer>[\s\S]*?)(?=Q\d+:|$)"
)


def format_split_text(text: str, raw_text: str, chunk_size: int = 512) -> List[Tuple[str, str]]:
    """Parse a model reply in ``Q1:/A1:`` format into question/answer tuples.

    Args:
        text: The model's reply, expected to contain ``Qn: ...`` / ``An: ...`` pairs.
        raw_text: The original source text, used only as a fallback.
        chunk_size: Length of each fallback chunk cut from *raw_text*.

    Returns:
        A list of ``(question, answer)`` tuples with surrounding whitespace
        stripped; pairs whose question is empty are dropped. If no pair is
        found, *raw_text* is split into consecutive *chunk_size*-character
        pieces and returned as ``(chunk, '')`` tuples instead.

    Raises:
        ValueError: If *chunk_size* is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # Turn escaped newlines that arrived as literal characters back into real
    # line breaks. The doubly-escaped form is handled first so it does not
    # leave a stray backslash behind.
    text = text.replace("\\\\n", "\n").replace("\\n", "\n")

    result = []
    for match in _QA_PAIR_PATTERN.finditer(text):
        question = match.group("question").strip()
        answer = match.group("answer").strip()
        if question:
            result.append((question, answer))

    # If nothing could be parsed, fall back to splitting the raw text into chunks.
    if not result:
        result = [
            (raw_text[start:start + chunk_size], '')
            for start in range(0, len(raw_text), chunk_size)
        ]
    return result
def main(text, api_key, api_url_base):
    """Generate Q/A pairs for *text* and hand them to the embedding step.

    Stores the credentials in the module-level ``API_KEY`` / ``API_URL``
    globals (which ``generate_qa`` reads), generates the pairs, then passes
    them to ``generate_embeddings_from_qa_pairs`` together with the original
    *api_key* and *api_url_base*.

    Args:
        text: Source text to turn into question/answer pairs.
        api_key: API key used as the bearer token.
        api_url_base: Base URL of the API; ``/chat/completions`` is appended.

    Returns:
        The list of ``(question, answer)`` tuples from ``generate_qa``.
    """
    global API_KEY, API_URL
    API_KEY = api_key
    # Drop trailing slashes so a base such as "https://host/v1/" does not
    # produce "https://host/v1//chat/completions".
    base = str(api_url_base).rstrip("/")
    API_URL = f"{base}/chat/completions"
    qa_pairs = generate_qa(text)
    generate_embeddings_from_qa_pairs(qa_pairs, api_key, api_url_base)
    return qa_pairs
if __name__ == "__main__":
    # Expected invocation: script.py <text> <api_key> <api_url_base>
    # Fail with a usage message instead of an IndexError when arguments are missing.
    if len(sys.argv) < 4:
        print(f"Usage: {sys.argv[0]} <text> <api_key> <api_url_base>", file=sys.stderr)
        sys.exit(2)
    text = sys.argv[1]  # The search result passed on the command line becomes the text variable
    api_key = sys.argv[2]
    api_url_base = sys.argv[3]
    main(text, api_key, api_url_base)