liangyi_LLaMA_Factory / data /preprocess_data /update_fine_grained_instructions.py
Mickey25's picture
Upload folder using huggingface_hub
46b244e verified
#!/usr/bin/env python3
"""
更新已有的细粒度数据到最新的instruction格式
简化instruction,使其更加清晰易懂
"""
import json
import os
from typing import Dict, List, Any
import argparse
import re
# 导入最新的资源结构定义
from generate_fine_grained_data import (
RESOURCE_DETAIL_STRUCTURE,
get_field_description,
get_all_resource_names
)
def create_updated_instruction(field_name: str, field_description: str, resource_name: str = None) -> str:
"""创建更新后的指令格式"""
# 特殊处理resource_names字段
if field_name == "resource_names":
all_resource_names = get_all_resource_names()
resource_list = "、".join(all_resource_names)
return f"""请从OCR文本中抽取旅行订单中的所有资源名称。
可识别的资源名称包括:{resource_list}
资源别称表述:
- 大慈岩丛林速滑(旱滑道)下行:可能表述为"旱滑道"、"丛林速滑"
- 灵栖洞西游魔毯:可能表述为"飞天魔毯"、"灵栖洞魔毯"
- 灵栖洞极速滑道:可能表述为"灵栖洞滑道"、"灵栖洞速滑"
- 考拉森林丛林探险:可能表述为"考拉森林"、"丛林探险"
- 三江口游线(游船):可能表述为"富春江游船"
- 严州古城:可能表述为"梅城"
- 三都渔村婚礼表演:可能表述为"三都渔村九姓渔氏水上婚礼表演"、"九姓渔氏"、"婚礼表演"
资源内在联系规则:
1. 七里扬帆景区:
- 出现"七里扬帆"(无三江口)信息时,包含:七里扬帆-七里扬帆门票、七里扬帆-七里扬帆游船
- 出现"三江口"信息时,包含:七里扬帆-三江口游线(游船),但不包含七里扬帆门票和七里扬帆游船
- 出现"葫芦峡漂流"信息时,包含:七里扬帆-七里扬帆门票、七里扬帆-七里扬帆游船
2. 灵栖洞景区:
- 出现"灵栖洞"信息时,包含:灵栖洞-灵栖洞门票、灵栖洞-灵栖洞手划船
严格按照以下JSON格式输出:
{{
"resource_names": ["资源主体-资源名称1", "资源主体-资源名称2"] 或 []
}}"""
# 特殊处理resource_detail字段
if field_name == "resource_detail" and resource_name:
detail_structure = RESOURCE_DETAIL_STRUCTURE.get(resource_name, {})
if detail_structure:
field_lines = []
for field_type, enum_values in detail_structure.items():
if enum_values:
enum_list = "、".join(f'"{str(v)}"' for v in enum_values)
field_lines.append(f"{field_type} ({get_field_description(field_type)}): {enum_list}")
else:
field_lines.append(f"{field_type} ({get_field_description(field_type)}): null")
fields_info = "\n".join(field_lines)
# 构建JSON示例结构
json_fields = []
for field_type in detail_structure.keys():
json_fields.append(f' "{field_type}": 选择值或null')
json_structure = ",\n".join(json_fields)
return f"""请从OCR文本中抽取旅行订单中{resource_name}的详细信息。
提取字段及可选值:
{fields_info}
严格按照以下JSON格式输出:
{{
"resource_detail": {{
{json_structure}
}}
}}"""
else:
return f"""请从OCR文本中抽取旅行订单中{resource_name}的详细信息。
严格按照以下JSON格式输出:
{{
"resource_detail": {{}}
}}"""
# 基础字段的指令定义
field_instructions = {
"team_size": """请从OCR文本中抽取旅行订单的总人数信息。
严格按照以下JSON格式输出:
{
"team_size": 整数或null
}
注意:订单信息中的导游员、讲解员、司机、领队等人员,不包含在总人数中。
""",
"start_date": """请从OCR文本中抽取旅行订单的开始日期信息。
严格按照以下JSON格式输出:
{
"start_date": "YYYY-MM-DD"或null
}""",
"end_date": """请从OCR文本中抽取旅行订单的结束日期信息。
严格按照以下JSON格式输出:
{
"end_date": "YYYY-MM-DD"或null
}""",
"payment_method": """请从OCR文本中抽取旅行订单的支付方式信息。
严格按照以下JSON格式输出:
{
"payment_method": "支付方式"或null
}""",
"customer_name": """请从OCR文本中抽取旅行订单的客户名称(通常是旅行社或公司名称)。
严格按照以下JSON格式输出:
{
"customer_name": "客户名称"或null
}""",
"customer_market": """请从OCR文本中抽取旅行订单的客户地区,按照'省-市'格式。
严格按照以下JSON格式输出:
{
"customer_market": "省-市"或null
}""",
"customer_type": """请从OCR文本中抽取旅行订单的客户类型(如旅行社、机构等)。
严格按照以下JSON格式输出:
{
"customer_type": "客户类型"或null
}""",
"notes": """请从OCR文本中抽取旅行订单的备注信息。
严格按照以下JSON格式输出:
{
"notes": "备注信息"或null
}""",
"contacts": """请从OCR文本中抽取旅行订单的联系人信息。
提取字段:
name (联系人姓名)
phone (联系电话)
idcard (身份证号码,如果没有则为null)
严格按照以下JSON格式输出:
{
"contacts": {
"data": [
{
"name": "姓名",
"phone": "电话",
"idcard": "身份证号"或null
}
]
}
}""",
"resource_start_time": f"""请从OCR文本中抽取旅行订单中{resource_name or '该资源'}的开始时间。
严格按照以下JSON格式输出:
{{
"resource_start_time": "YYYY-MM-DD"或null
}}""",
"resource_end_time": f"""请从OCR文本中抽取旅行订单中{resource_name or '该资源'}的结束时间。
严格按照以下JSON格式输出:
{{
"resource_end_time": "YYYY-MM-DD"或null
}}""",
"resource_team_size": f"""请从OCR文本中抽取旅行订单中{resource_name or '该资源'}的使用人数。
严格按照以下JSON格式输出:
{{
"resource_team_size": 整数或null
}}
注意:订单信息中的导游员、讲解员、司机、领队等人员,不包含在总人数中。
"""
}
return field_instructions.get(field_name, f"""请从OCR文本中抽取旅行订单的{field_description}信息。
严格按照以下JSON格式输出:
{{
"{field_name}": "值"或null
}}""")
def analyze_old_instruction(instruction: str) -> Dict[str, str]:
"""分析旧的instruction,提取字段信息"""
result = {
"field_name": None,
"field_description": None,
"resource_name": None
}
# 分析输出格式来确定字段名
if '"team_size"' in instruction:
result["field_name"] = "team_size"
result["field_description"] = "总人数"
elif '"start_date"' in instruction:
result["field_name"] = "start_date"
result["field_description"] = "开始日期"
elif '"end_date"' in instruction:
result["field_name"] = "end_date"
result["field_description"] = "结束日期"
elif '"payment_method"' in instruction:
result["field_name"] = "payment_method"
result["field_description"] = "支付方式"
elif '"customer_name"' in instruction:
result["field_name"] = "customer_name"
result["field_description"] = "客户名称"
elif '"customer_market"' in instruction:
result["field_name"] = "customer_market"
result["field_description"] = "客户地区"
elif '"customer_type"' in instruction:
result["field_name"] = "customer_type"
result["field_description"] = "客户类型"
elif '"notes"' in instruction:
result["field_name"] = "notes"
result["field_description"] = "备注"
elif '"contacts"' in instruction:
result["field_name"] = "contacts"
result["field_description"] = "联系人信息"
elif '"resource_names"' in instruction:
result["field_name"] = "resource_names"
result["field_description"] = "资源名称列表"
elif '"resource_start_time"' in instruction:
result["field_name"] = "resource_start_time"
result["field_description"] = "开始时间"
elif '"resource_end_time"' in instruction:
result["field_name"] = "resource_end_time"
result["field_description"] = "结束时间"
elif '"resource_team_size"' in instruction:
result["field_name"] = "resource_team_size"
result["field_description"] = "使用人数"
elif '"resource_detail"' in instruction:
result["field_name"] = "resource_detail"
result["field_description"] = "详细信息"
# 提取资源名称(如果有的话)
if "请从OCR文本中抽取旅行订单中" in instruction and result["field_name"] in ["resource_start_time", "resource_end_time", "resource_team_size", "resource_detail"]:
# 尝试从instruction中提取资源名称
lines = instruction.split('\n')
for line in lines:
if "请从OCR文本中抽取旅行订单中" in line and "的" in line:
# 找到资源名称
start_idx = line.find("请从OCR文本中抽取旅行订单中") + len("请从OCR文本中抽取旅行订单中")
end_idx = line.find("的", start_idx)
if end_idx > start_idx:
resource_part = line[start_idx:end_idx]
result["resource_name"] = resource_part
break
return result
def parse_old_output(output: str, field_name: str) -> Any:
"""解析旧的output格式,提取实际值"""
if not output:
return None
# 1) 优先按整体JSON解析(适配例如: "{"team_size": 36}" 这种字符串)
try:
parsed = json.loads(output) if isinstance(output, str) else output
if isinstance(parsed, dict) and field_name in parsed:
return parsed[field_name]
except Exception:
pass
# 2) 回退到基于键名的提取逻辑,尽量健壮地抽取值片段
if isinstance(output, str) and f'"{field_name}"' in output:
# 使用正则尝试获取字段值的原始JSON片段
# 匹配 null/布尔/数字/字符串/对象/数组
pattern = rf'"{re.escape(field_name)}"\s*:\s*(null|true|false|-?\d+(?:\.\d+)?|"[^"\\]*(?:\\.[^"\\]*)*"|\{{[^\}}]*\}}|\[[^\]]*\])'
match = re.search(pattern, output, flags=re.IGNORECASE | re.DOTALL)
if match:
raw_value = match.group(1).strip()
# 先尝试按JSON解析
try:
return json.loads(raw_value)
except Exception:
# 处理裸值
if raw_value == 'null':
return None
if raw_value.startswith('"') and raw_value.endswith('"'):
return raw_value[1:-1]
# 数字
try:
return int(raw_value)
except Exception:
try:
return float(raw_value)
except Exception:
return raw_value
return None
def format_new_output(value: Any, field_name: str) -> str:
"""格式化新的output为标准JSON格式"""
result = {field_name: value}
return json.dumps(result, ensure_ascii=False)
def update_fine_grained_data(input_file: str, output_file: str):
"""更新细粒度数据到最新格式"""
print(f"正在加载数据: {input_file}")
# 加载数据 - 支持JSON数组格式和JSONL格式
data = []
is_json_array_format = False # 记录输入格式类型
with open(input_file, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 检测文件格式
if content.startswith('[') and content.endswith(']'):
# JSON数组格式
try:
data = json.loads(content)
is_json_array_format = True
print(f"检测到JSON数组格式")
except json.JSONDecodeError as e:
print(f"JSON数组解析失败: {e}")
return
else:
# JSONL格式 - 按行解析
print(f"检测到JSONL格式")
f.seek(0) # 重置文件指针
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
try:
item = json.loads(line)
data.append(item)
except json.JSONDecodeError as e:
print(f"警告: 第{line_num}行JSON解析失败: {e}")
continue
print(f"加载了 {len(data)} 条数据")
# 更新数据
updated_data = []
update_stats = {}
for item in data:
old_instruction = item["instruction"]
old_output = item["output"]
# 分析旧instruction
analysis = analyze_old_instruction(old_instruction)
if analysis["field_name"]:
# 生成新instruction
new_instruction = create_updated_instruction(
analysis["field_name"],
analysis["field_description"],
analysis["resource_name"]
)
# 解析并转换output格式
parsed_value = parse_old_output(old_output, analysis["field_name"])
new_output = format_new_output(parsed_value, analysis["field_name"])
# 更新数据项:在原条目基础上修改,保留未知字段(例如 context)
updated_item = dict(item)
updated_item["instruction"] = new_instruction
updated_item["output"] = new_output
updated_data.append(updated_item)
# 统计更新情况
field_key = analysis["field_name"]
if analysis["resource_name"]:
field_key = f"{analysis['field_name']}({analysis['resource_name']})"
update_stats[field_key] = update_stats.get(field_key, 0) + 1
else:
# 无法识别的格式,保持原样
updated_data.append(item)
update_stats["未识别"] = update_stats.get("未识别", 0) + 1
# 保存更新后的数据 - 根据输入格式选择输出格式
print(f"正在保存更新后的数据: {output_file}")
with open(output_file, 'w', encoding='utf-8') as f:
if is_json_array_format:
# 输出JSON数组格式
print("保存为JSON数组格式")
json.dump(updated_data, f, ensure_ascii=False, indent=2)
else:
# 输出JSONL格式
print("保存为JSONL格式")
for item in updated_data:
f.write(json.dumps(item, ensure_ascii=False) + '\n')
# 输出统计信息
print(f"\n更新完成!共处理 {len(updated_data)} 条数据")
print("\n各字段更新统计:")
for field, count in sorted(update_stats.items()):
print(f" {field}: {count} 条")
def main():
parser = argparse.ArgumentParser(description='更新细粒度数据到最新instruction格式')
parser.add_argument('--input', '-i', type=str, required=True,
help='输入文件路径')
parser.add_argument('--output', '-o', type=str, required=True,
help='输出文件路径')
args = parser.parse_args()
# 检查输入文件是否存在
if not os.path.exists(args.input):
print(f"错误: 输入文件不存在: {args.input}")
return
# 创建输出目录(如果不存在)
output_dir = os.path.dirname(args.output)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)
try:
update_fine_grained_data(args.input, args.output)
except Exception as e:
print(f"更新失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()