xiaoximew's picture
Init
031beb8
import json
import os
from typing import Tuple
import pandas as pd
from common.call_llm import chat
EXTRACT_ENDPOINT = os.environ.get("EXTRACT_ENDPOINT")
prompt_template = """### 角色能力 ###
你是一个信息提取助手,你可以按下面给出的提取字段及描述对文档内容进行信息提取,并按给定的格式返回。
确保提取的信息完整且与文档内容一致,如果有字段提取不到对应信息请返回'无'。
### 提取字段及描述 ###
{fields_prompt}
### 文档内容 ###
{context}
### 返回格式 ###
请严格按照下面描述的JSON格式进行输出,不需要解释,输出JSON格式如下:
{response_prompt}
确保输出的格式可以被Python的json.loads方法解析。
"""
def extract_slots(page_content: str, extraction_df: "pd.DataFrame") -> Tuple[str, None]:
"""
Extract slots from page content
:param page_content:
:param extract_requirement:
:return:
"""
extract_requirement = ""
output_requirement = dict()
df = pd.DataFrame(columns=["字段名称", "字段抽取结果"])
# remove nan
extraction_df = extraction_df[extraction_df['字段名称'].notna()]
for _, row in extraction_df.iterrows():
if not row['字段名称'] or not row['字段描述']:
continue
extract_requirement += f"{row['字段名称']}: {row['字段描述']}\n"
output_requirement[row['字段名称']] = row['字段描述']
if not output_requirement:
return df
output_requirement_description = json.dumps([output_requirement], ensure_ascii=False, indent=4)
prompt = prompt_template.format(context=page_content, fields_prompt=extract_requirement, response_prompt=output_requirement_description)
messages = [{"role": "user", "content": prompt}]
max_retry = 6
retry = 0
result = None
while not result and retry < max_retry:
try:
result = chat(messages=messages, endpoint=EXTRACT_ENDPOINT)
if result.startswith("```json"):
result = result.replace("```json", "").replace("```", "").strip()
result = json.loads(result)
if isinstance(result, list):
result = result[0]
except Exception as e:
print(f"error: {e} {result}")
result = None
retry += 1
print(f"extract slots prompt: {prompt} result: {result}")
if not result:
return df
for field in output_requirement:
df.loc[len(df)] = {"字段名称": field, "字段抽取结果": result.get(field, "无")}
return df