|
import json |
|
import os |
|
from typing import Tuple |
|
|
|
import pandas as pd |
|
|
|
from common.call_llm import chat |
|
|
|
# Endpoint handed to ``chat`` for extraction requests; read from the
# environment, so it is None when EXTRACT_ENDPOINT is not set.
EXTRACT_ENDPOINT = os.environ.get("EXTRACT_ENDPOINT")


# Prompt sent to the model (kept in Chinese: it is runtime text, not a comment).
# Placeholders filled in with ``str.format``:
#   {fields_prompt}   - the fields to extract, with their descriptions
#   {context}         - the document text to extract from
#   {response_prompt} - a JSON skeleton showing the expected output shape
prompt_template = """### 角色能力 ###
你是一个信息提取助手,你可以按下面给出的提取字段及描述对文档内容进行信息提取,并按给定的格式返回。
确保提取的信息完整且与文档内容一致,如果有字段提取不到对应信息请返回'无'。

### 提取字段及描述 ###
{fields_prompt}

### 文档内容 ###
{context}

### 返回格式 ###
请严格按照下面描述的JSON格式进行输出,不需要解释,输出JSON格式如下:
{response_prompt}
确保输出的格式可以被Python的json.loads方法解析。
"""
|
|
|
|
|
def _parse_slot_reply(reply):
    """Decode the model's reply into a ``{field name: value}`` dict.

    Returns None when the reply is not a string, or when it does not decode
    to a JSON object (or to a list whose first element is a JSON object).
    Raises ``json.JSONDecodeError`` when the text is not valid JSON.
    """
    if not isinstance(reply, str):
        return None

    text = reply.strip()
    # Models often wrap JSON in a Markdown fence (```json ... ``` or ``` ... ```).
    if text.startswith("```"):
        text = text.replace("```json", "").replace("```", "").strip()

    parsed = json.loads(text)
    if isinstance(parsed, list):
        # The prompt shows the expected output as a one-element list.
        parsed = parsed[0] if parsed else None
    return parsed if isinstance(parsed, dict) else None


def extract_slots(page_content: str, extraction_df: "pd.DataFrame") -> "pd.DataFrame":
    """Extract the requested fields from a document by prompting an LLM.

    The field names and descriptions in ``extraction_df`` are rendered into
    ``prompt_template`` together with ``page_content``; the prompt is sent
    through ``chat`` and the reply is parsed as JSON. The call is retried up
    to six times until a non-empty JSON object is obtained.

    :param page_content: Text of the document to extract from.
    :param extraction_df: Table with a ``字段名称`` (field name) column and a
        ``字段描述`` (field description) column. Rows whose name or
        description is missing (NaN/None) or empty are ignored.
    :return: DataFrame with columns ``字段名称`` and ``字段抽取结果``, one row
        per requested field; a field absent from the reply is reported as
        ``"无"``. The frame is empty when no usable field was requested or
        when no usable reply was obtained after all retries.
    """
    result_columns = ["字段名称", "字段抽取结果"]
    empty_result = pd.DataFrame(columns=result_columns)

    named_rows = extraction_df[extraction_df["字段名称"].notna()]

    requirement_lines = []
    output_requirement = {}
    for _, row in named_rows.iterrows():
        name, description = row["字段名称"], row["字段描述"]
        # NaN is truthy, so a plain ``not description`` would let a missing
        # description through and put "nan" into the prompt; test it first.
        if pd.isna(description) or not name or not description:
            continue
        requirement_lines.append(f"{name}: {description}\n")
        output_requirement[name] = description

    if not output_requirement:
        return empty_result

    extract_requirement = "".join(requirement_lines)
    output_requirement_description = json.dumps(
        [output_requirement], ensure_ascii=False, indent=4
    )
    prompt = prompt_template.format(
        context=page_content,
        fields_prompt=extract_requirement,
        response_prompt=output_requirement_description,
    )
    messages = [{"role": "user", "content": prompt}]

    max_retry = 6
    result = None
    for _ in range(max_retry):
        reply = None
        try:
            reply = chat(messages=messages, endpoint=EXTRACT_ENDPOINT)
            result = _parse_slot_reply(reply)
        except Exception as e:
            # Deliberate best-effort retry boundary: any failure of the call
            # or of JSON decoding is reported and the request is retried.
            print(f"error: {e} {reply}")
            result = None
        if result:
            break

    print(f"extract slots prompt: {prompt} result: {result}")

    if not result:
        return empty_result

    rows = [
        {"字段名称": field, "字段抽取结果": result.get(field, "无")}
        for field in output_requirement
    ]
    return pd.DataFrame(rows, columns=result_columns)
|
|