# MikeMai's picture
# Update app.py
# 951ce76 verified
import os
from dotenv import load_dotenv
load_dotenv()
import json
import pandas as pd
import zipfile
import xml.etree.ElementTree as ET
from io import BytesIO
import openpyxl
from openai import OpenAI
import re
from pydantic import BaseModel, Field, ValidationError, RootModel
from typing import List, Optional
from fuzzywuzzy import fuzz
from fuzzywuzzy import process
# API key for the Hugging Face router, read from the environment
# (load_dotenv() is called at the top of the file).
HF_API_KEY = os.getenv("HF_API_KEY")
# Exactly one base_url/model pair below is active; the others are kept as
# commented-out alternatives.
# Deepseek R1 Distilled Qwen 2.5 14B --------------------------------
# base_url = "https://router.huggingface.co/novita"
# model = "deepseek/deepseek-r1-distill-qwen-14b"
# Deepseek R1 Distilled Qwen 2.5 32B --------------------------------
# base_url = "https://router.huggingface.co/hf-inference/models/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B/v1"
# model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
# Qwen 2.5 7B (active) ------------------------------------------------
base_url = "https://router.huggingface.co/together/v1"
model= "Qwen/Qwen2.5-7B-Instruct-Turbo"
# Qwen 2.5 72B (the model id below is the 72b variant) ----------------
# base_url = "https://router.huggingface.co/novita/v3/openai"
# model="qwen/qwen-2.5-72b-instruct"
# Qwen 3 32B --------------------------------------------------------
# base_url = "https://router.huggingface.co/sambanova/v1"
# model="Qwen3-32B"
# Default Word XML namespace, used when the document root carries none
DEFAULT_NS = {'w': 'http://schemas.openxmlformats.org/wordprocessingml/2006/main'}
# Prefix map passed to every find()/findall() call in this module.
# It stays None until get_namespace() has been called.
NS = None # Global variable to store the namespace
def get_namespace(root):
    """Derive the ``w`` prefix map from the XML root element and store it in ``NS``.

    ElementTree spells a qualified tag as ``{namespace-uri}localname``.  The
    URI between the braces becomes the ``w`` namespace.  When the root tag is
    not qualified (no leading ``{``), ``DEFAULT_NS`` is used instead.

    Args:
        root: Root element of the parsed document.

    Returns:
        The prefix map that was stored in the module-level ``NS``.
    """
    global NS
    tag = root.tag
    uri = ""
    # Only a tag that really starts with "{uri}" carries a namespace.  A bare
    # tag such as "document" must not be mistaken for a namespace URI.
    if isinstance(tag, str) and tag.startswith('{') and '}' in tag:
        uri = tag[1:tag.index('}')]
    NS = {'w': uri} if uri else DEFAULT_NS
    return NS
# --- Helper Functions for DOCX Processing ---
def extract_text_from_cell(cell):
    """Return the text of a Word table cell as a list of lines.

    Each ``w:p`` paragraph inside the cell yields at most one line: the text
    of all its ``w:t`` runs glued together without a separator (so words that
    Word split across runs are rebuilt) and stripped.  Paragraphs that end up
    empty are left out.
    """
    collected = []
    for para in cell.findall('.//w:p', NS):
        fragments = (node.text for node in para.findall('.//w:t', NS))
        merged = ''.join(piece for piece in fragments if piece).strip()
        if merged:
            collected.append(merged)
    return collected
def clean_spaces(text):
    r"""Normalise spacing in mixed Chinese/English text.

    - Whitespace between two Chinese characters is removed.
    - Exactly one space is placed between a Chinese character and an
      adjacent ASCII letter (in either order).
    - Any remaining whitespace run collapses to a single space and the
      result is stripped.

    Falsy or non-string input is returned unchanged.
    """
    if not text or not isinstance(text, str):
        return text
    # Look-around keeps the neighbouring characters unconsumed, so every gap
    # in a run such as "中 文 字" is removed.  A consuming pattern would skip
    # every second gap.
    text = re.sub(r'(?<=[\u4e00-\u9fff])\s+(?=[\u4e00-\u9fff])', '', text)
    # Ensure one space between Chinese and English
    text = re.sub(r'([\u4e00-\u9fff])\s*([a-zA-Z])', r'\1 \2', text)
    text = re.sub(r'([a-zA-Z])\s*([\u4e00-\u9fff])', r'\1 \2', text)
    # Normalize multiple spaces to single space
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
def extract_key_value_pairs(text, target_dict=None):
    r"""Extract ``key: value`` pairs from one line of text.

    The line is cut at runs of three or more whitespace characters
    (``\s{3,}``).  A piece only starts a new pair when it contains a colon.
    A piece without a colon is treated as the continuation of the piece
    before it and re-attached with a single space.  Each resulting piece is
    then split at its first colon.

    Args:
        text: The line to parse.  Full-width colons (U+FF1A) are treated
            like ASCII colons.
        target_dict: Optional dict to fill.  When a key is already present
            the new value is appended on a new line.

    Returns:
        ``target_dict`` (or a new dict) containing the extracted pairs.
    """
    if target_dict is None:
        target_dict = {}
    # Normalize Chinese (full-width) colons to ASCII colons.  The escape is
    # spelled out so the replacement survives any re-encoding of this file.
    text = text.replace("\uff1a", ":")

    merged_segments = []
    for segment in re.split(r'\s{3,}', text):
        segment = segment.strip()
        if not segment:
            continue
        if ":" in segment or not merged_segments:
            merged_segments.append(segment)
        else:
            # No colon: this is the tail of the previous segment's value.
            merged_segments[-1] += " " + segment

    for segment in merged_segments:
        if ':' not in segment:
            continue
        key, value = segment.split(':', 1)  # Only split at the first colon
        key, value = key.strip(), value.strip()
        if key in target_dict:
            target_dict[key] += "\n" + value  # Append if key already exists
        else:
            target_dict[key] = value
    return target_dict
# --- Table Processing Functions ---
def process_unknown_table(rows):
    """Collect the text lines of every single-cell row of an unclassified table.

    Rows with any other number of cells are ignored.  The lines are returned
    as one flat list in document order.
    """
    collected = []
    for row in rows:
        row_cells = row.findall('.//w:tc', NS)
        if len(row_cells) != 1:
            continue
        # Keep each line of the cell as-is, without further splitting.
        collected += extract_text_from_cell(row_cells[0])
    return collected
def process_buyer_seller_table(rows):
    """Turn a two-column buyer/seller table into ``{role: {key: value}}``.

    The first row decides which column is the buyer and which is the seller,
    based on keywords in the first text line of each header cell.  Every
    two-cell row (the header row included) is then parsed line by line with
    ``extract_key_value_pairs``.

    Args:
        rows: ``w:tr`` elements of the table, header row first.

    Returns:
        A dict with the keys ``'buyer_info'`` and ``'seller_info'``, or
        ``None`` when ``rows`` is empty or the first row does not have
        exactly two cells.
    """
    if not rows:
        return None
    headers = [extract_text_from_cell(cell) for cell in rows[0].findall('.//w:tc', NS)]
    if len(headers) != 2:
        return None  # Not a buyer-seller table

    def get_role(header_lines):
        """Return the role named by a header cell, or None if it names none."""
        # An empty header cell has no lines; treat it as "no keyword".
        header_text = header_lines[0].lower() if header_lines else ""
        if '买方' in header_text or 'buyer' in header_text or '甲方' in header_text:
            return 'buyer_info'
        if '卖方' in header_text or 'seller' in header_text or '乙方' in header_text:
            return 'seller_info'
        return None

    opposite = {'buyer_info': 'seller_info', 'seller_info': 'buyer_info'}
    buyer_key, seller_key = get_role(headers[0]), get_role(headers[1])
    if buyer_key is None and seller_key is None:
        buyer_key, seller_key = 'buyer_info', 'seller_info'
    elif buyer_key is None:
        # Only the second column is labelled: the first takes the other role.
        buyer_key = opposite[seller_key]
    elif seller_key is None:
        seller_key = opposite[buyer_key]
    elif buyer_key == seller_key:
        # Both headers name the same role; fall back to positional defaults
        # so the two columns are never merged into one dict.
        buyer_key, seller_key = 'buyer_info', 'seller_info'

    buyer_seller_data = {buyer_key: {}, seller_key: {}}
    for row in rows:
        cells = row.findall('.//w:tc', NS)
        if len(cells) != 2:
            continue
        for line in extract_text_from_cell(cells[0]):
            extract_key_value_pairs(line, buyer_seller_data[buyer_key])
        for line in extract_text_from_cell(cells[1]):
            extract_key_value_pairs(line, buyer_seller_data[seller_key])
    return buyer_seller_data
def process_summary_table(rows):
    """Convert a two-column summary table into a list of one-entry dicts.

    For every row with exactly two cells, the lines of the first cell (joined
    with spaces) become the key and the lines of the second cell (joined the
    same way) become the value.  Other rows are skipped.
    """
    pairs = []
    for row in rows:
        row_cells = row.findall('.//w:tc', NS)
        if len(row_cells) != 2:
            continue
        label, content = (" ".join(extract_text_from_cell(c)) for c in row_cells)
        pairs.append({label: content})
    return pairs
def clean_header_spaces(headers):
    """Normalise a list of header strings for matching.

    Every non-empty header has its whitespace runs collapsed to one space,
    gets exactly one space between a Chinese character and an adjacent ASCII
    letter, and is lowercased and stripped.  Falsy headers are passed through
    untouched, and a falsy ``headers`` argument is returned as-is.
    """
    if not headers:
        return headers

    def _normalise(raw):
        if not raw:
            return raw
        collapsed = re.sub(r'\s+', ' ', raw)
        boundary_patterns = (
            r'([\u4e00-\u9fff])\s*([a-zA-Z])',
            r'([a-zA-Z])\s*([\u4e00-\u9fff])',
        )
        for pattern in boundary_patterns:
            collapsed = re.sub(pattern, r'\1 \2', collapsed)
        # Final pass over any whitespace runs, then case and edge cleanup.
        return re.sub(r'\s+', ' ', collapsed).lower().strip()

    return [_normalise(item) for item in headers]
def extract_headers(first_row_cells):
    """Build one unique column name per grid column from a table's header cells.

    A cell that spans several columns (``w:gridSpan``) contributes its text
    once per spanned column.  The second and later occurrences of the same
    text get an ``_<n>`` suffix.  A name that would be empty is replaced by
    ``Column_<position>``.
    """
    names = []
    seen = {}
    val_attr = None
    for cell in first_row_cells:
        text = " ".join(extract_text_from_cell(cell))
        span_node = cell.find('.//w:gridSpan', NS)
        if span_node is None:
            width = 1
        else:
            val_attr = f'{{{NS["w"]}}}val'
            width = int(span_node.attrib.get(val_attr, '1'))
        for _ in range(width):
            occurrence = seen.get(text, 0) + 1
            seen[text] = occurrence
            label = text if occurrence == 1 else f"{text}_{occurrence}"
            names.append(label or f"Column_{len(names) + 1}")
    return names
# Substrings that mark a subtotal / total / discount / "see attachment" row.
_LONG_TABLE_TOTAL_MARKERS = ("小计", "总金额", "合计", "折扣", "明细见附件")
# Substrings identifying a remarks column, which is exempt from the check above.
_LONG_TABLE_REMARK_TERMS = ("备注", "remarks", "note", "notes")


def _long_table_cell_span(cell):
    """Return the number of grid columns a cell covers (``w:gridSpan``, default 1)."""
    grid_span = cell.find('.//w:gridSpan', NS)
    if grid_span is None:
        return 1
    # Default to '1' when the attribute is absent, as extract_headers does.
    return int(grid_span.attrib.get(f'{{{NS["w"]}}}val', '1'))


def _long_table_row_to_dict(cells, headers, vertical_merge_tracker):
    """Map one row's cells onto ``headers``, expanding horizontal and vertical merges."""
    row_data = {}
    running_index = 0
    for cell in cells:
        cell_text = " ".join(extract_text_from_cell(cell))
        col_span = _long_table_cell_span(cell)
        v_merge = cell.find('.//w:vMerge', NS)
        if v_merge is not None:
            if v_merge.attrib.get(f'{{{NS["w"]}}}val') == 'restart':
                # First cell of a vertical merge: remember its text.
                vertical_merge_tracker[running_index] = cell_text
            else:
                # Continuation cell: repeat the text recorded for this column.
                cell_text = vertical_merge_tracker.get(running_index, "")
        end_col = running_index + col_span
        # Repeat the value for each spanned column.
        for col in range(running_index, end_col):
            key = headers[col] if col < len(headers) else f"Column_{col+1}"
            row_data[key] = cell_text
        running_index = end_col
    # Fill remaining columns with empty strings to maintain alignment.
    for header in headers[running_index:]:
        row_data[header] = ""
    return row_data


def _long_table_is_total_row(row):
    """Return True when a non-remarks cell contains a total/discount marker."""
    for key, value in row.items():
        key_lower = key.lower()
        if any(term in key_lower for term in _LONG_TABLE_REMARK_TERMS):
            continue  # Skip remarks column
        if isinstance(value, str) and any(marker in value for marker in _LONG_TABLE_TOTAL_MARKERS):
            return True
    return False


def _long_table_has_usable_serial(row):
    """Return False only when a "序号" column exists and its value has no digit."""
    serial_number = next((row[column] for column in row if "序号" in column.lower()), None)
    if serial_number is None:
        # No serial-number column: nothing to filter on, keep the row.
        return True
    # "1" or "2." are kept; blank values or text such as "No." are dropped.
    return re.search(r'\d', serial_number) is not None


def process_long_table(rows):
    """Convert a multi-column table into a list of ``{header: text}`` row dicts.

    The first row supplies the column names (see ``extract_headers``).  For
    every later row:

    * rows with two or fewer cells are skipped (treated as merged rows);
    * horizontally merged cells repeat their text in each spanned column,
      and vertically merged continuation cells repeat the text recorded at
      the merge start;
    * keys are normalised with ``clean_header_spaces``;
    * rows that look like totals or discounts, and rows whose "序号" value
      contains no digit, are dropped.

    The surviving rows are passed through ``merge_duplicate_columns``
    (defined elsewhere in this file) before being returned.

    Args:
        rows: ``w:tr`` elements, header row first.

    Returns:
        The filtered list of row dicts; ``[]`` when ``rows`` is empty.
    """
    if not rows:
        return []  # Avoid IndexError
    headers = extract_headers(rows[0].findall('.//w:tc', NS))
    vertical_merge_tracker = {}
    kept_rows = []
    for row in rows[1:]:
        cells = row.findall('.//w:tc', NS)
        # Skip rows with only 1 or 2 columns (merged cells)
        if len(cells) <= 2:
            continue
        raw_row = _long_table_row_to_dict(cells, headers, vertical_merge_tracker)
        cleaned_row = {
            clean_header_spaces([key])[0]: value for key, value in raw_row.items()
        }
        if _long_table_is_total_row(cleaned_row):
            continue
        if not _long_table_has_usable_serial(cleaned_row):
            continue
        kept_rows.append(cleaned_row)
    # Remove duplicate columns (ending with _2, _3, etc.)
    return merge_duplicate_columns(kept_rows)
def identify_table_type_and_header_row(rows):
    """Classify a table and report the index of its first multi-cell row.

    The first row with more than one cell decides the type:

    * more than two cells -> ``"long_table"``
    * two cells, but not every row of the table has two -> ``"summary"``
    * two cells in every row and a buyer/seller keyword in this row
      -> ``"buyer_seller"``
    * two cells in every row without such a keyword -> ``"unknown"``

    When no row has more than one cell the result is ``("unknown", 0)``.
    """
    for index, row in enumerate(rows):
        row_cells = row.findall('.//w:tc', NS)
        width = len(row_cells)
        if width <= 1:
            continue
        if width > 2:
            # Anything wider than two columns is handled as a long table.
            return "long_table", index
        if not all(len(r.findall('.//w:tc', NS)) == 2 for r in rows):
            return "summary", index
        row_text = " ".join(
            " ".join(extract_text_from_cell(cell)) for cell in row_cells
        ).lower()
        if any(word in row_text for word in ("买方", "buyer", "卖方", "seller")):
            return "buyer_seller", index
        return "unknown", index
    # Fallback
    return "unknown", 0
def extract_tables(root):
    """Extract every top-level table of the document body as structured data.

    The direct children of ``w:body`` (or of ``root`` when there is no body
    element) are walked in order.  The most recent non-empty paragraph among
    the last three paragraphs seen is taken as a table's title; a table whose
    title contains 'template' is skipped, as is a table without rows.

    Returns:
        ``(table_data, table_paragraphs)``.  ``table_data`` maps generated
        keys (``table_<n>_unknown``, ``table_<n>_buyer_seller``,
        ``table_<n>_summary``, ``long_table_<n>``) to the parsed content.
        ``table_paragraphs`` is the set of paragraph elements that live
        inside processed tables.
    """
    body = root.find('.//w:body', NS)
    if body is None:
        body = root  # fallback if structure is different

    paragraph_tag = f'{{{NS["w"]}}}p'
    table_tag = f'{{{NS["w"]}}}tbl'

    table_data = {}
    table_paragraphs = set()
    table_index = 1
    recent_paragraphs = []  # rolling window of at most 3 paragraphs

    for elem in list(body):
        if elem.tag == paragraph_tag:
            recent_paragraphs = (recent_paragraphs + [elem])[-3:]
            continue
        if elem.tag != table_tag:
            continue

        # Title = nearest preceding paragraph (within the window) with text.
        title = ""
        for para in reversed(recent_paragraphs):
            candidate = ' '.join(
                t.text for t in para.findall('.//w:t', NS) if t.text
            ).strip()
            if candidate:
                title = candidate
                break
        if title and 'template' in title.lower():
            continue

        rows = elem.findall('.//w:tr', NS)
        if not rows:
            continue  # Skip empty tables
        table_paragraphs.update(elem.findall('.//w:p', NS))

        table_type, header_row_index = identify_table_type_and_header_row(rows)
        if table_type == "unknown":
            key = f"table_{table_index}_unknown"
            parsed = process_unknown_table(rows)
        elif table_type == "buyer_seller":
            key = f"table_{table_index}_buyer_seller"
            parsed = process_buyer_seller_table(rows[header_row_index:])
        elif table_type == "summary":
            key = f"table_{table_index}_summary"
            parsed = process_summary_table(rows[header_row_index:])
        elif table_type == "long_table":
            key = f"long_table_{table_index}"
            parsed = process_long_table(rows[header_row_index:])
        else:
            continue

        # The counter only advances for tables that produced data.
        if parsed:
            table_data[key] = parsed
            table_index += 1
    return table_data, table_paragraphs
# --- Non-Table Processing Functions ---
def extract_text_outside_tables(root, table_paragraphs):
    """Return the cleaned ``key: value``-looking lines found outside tables.

    Every ``w:p`` under ``root`` that is not in ``table_paragraphs`` has its
    run texts stripped and joined with spaces.  Full-width colons are
    converted to ASCII colons and the spacing is normalised with
    ``clean_spaces``.  Only lines that contain a colon are kept.

    Args:
        root: Root element of the document.
        table_paragraphs: Paragraph elements to skip (those inside tables).

    Returns:
        List of cleaned lines, in document order.
    """
    extracted_text = []
    for paragraph in root.findall('.//w:p', NS):
        if paragraph in table_paragraphs:
            continue  # Skip paragraphs inside tables
        texts = [t.text.strip() for t in paragraph.findall('.//w:t', NS) if t.text]
        # U+FF1A is the full-width (Chinese) colon.  It is written as an
        # escape so the normalisation cannot degrade into a no-op.
        line = clean_spaces(' '.join(texts).replace('\uff1a', ':'))
        if ':' in line:
            extracted_text.append(line)
    return extracted_text
# --- Main Extraction Functions ---
def extract_docx_as_xml(file_bytes, save_xml=False, xml_filename="document.xml"):
    """Read ``word/document.xml`` out of a DOCX archive and return it as text.

    Args:
        file_bytes: Seekable binary file object holding the DOCX (a zip
            archive).  It is rewound before reading.
        save_xml: When true, also write the XML to ``xml_filename``.
        xml_filename: Target path used when ``save_xml`` is set.

    Returns:
        The UTF-8 decoded content of ``word/document.xml``.
    """
    # Rewind: the caller may already have read from the stream.
    file_bytes.seek(0)
    with zipfile.ZipFile(file_bytes, 'r') as archive:
        xml_content = archive.read('word/document.xml').decode('utf-8')
    if save_xml:
        with open(xml_filename, "w", encoding="utf-8") as out:
            out.write(xml_content)
    return xml_content
def xml_to_json(xml_content, save_json=False, json_filename="extracted_data.json"):
    """Parse WordprocessingML text and return the extracted data as a JSON string.

    The result contains one entry per extracted table (see
    ``extract_tables``) plus ``"non_table_data"``, the colon-bearing lines
    found outside tables.

    Args:
        xml_content: Content of ``word/document.xml`` as a string.
        save_json: When true, also write the JSON to ``json_filename``.
        json_filename: Target path used when ``save_json`` is set.

    Returns:
        The extracted data serialised with ``ensure_ascii=False, indent=4``.
    """
    root = ET.fromstring(xml_content)
    # Every helper below looks elements up through the module-level NS map.
    # Initialise it from this document if no caller has done so yet, instead
    # of failing inside ElementTree with an unknown-prefix error.
    if NS is None:
        get_namespace(root)
    table_data, table_paragraphs = extract_tables(root)
    extracted_data = table_data
    extracted_data["non_table_data"] = extract_text_outside_tables(root, table_paragraphs)
    serialized = json.dumps(extracted_data, ensure_ascii=False, indent=4)
    if save_json:
        with open(json_filename, "w", encoding="utf-8") as f:
            f.write(serialized)
    return serialized
def extract_contract_summary(json_data, save_json=False, json_filename="contract_summary.json"):
    """Ask the LLM for the contract header fields and return them as validated JSON.

    The extracted document data (without the ``long_table`` entries) is
    embedded in a prompt.  The reply is stripped of ``<think>`` blocks and
    code fences, parsed, cleaned, and validated against ``ContractSummary``.
    On a JSON or validation error the error is fed back to the model, for up
    to three attempts in total.  If all attempts fail, an empty summary is
    returned.

    Args:
        json_data: JSON string as produced by ``xml_to_json``.
        save_json: When true, also write the summary to ``json_filename``.
        json_filename: Target path used when ``save_json`` is set.

    Returns:
        A JSON string.  NOTE(review): the summary JSON text is passed through
        ``json.dumps`` a second time, so the return value is a JSON-encoded
        *string* that itself contains JSON.  This is preserved because
        callers may decode twice — confirm before changing.
    """
    # Step 1: Convert JSON string to Python dictionary
    contract_data = json.loads(json_data)
    # Step 2: Remove keys that contain "long_table"
    filtered_contract_data = {key: value for key, value in contract_data.items() if "long_table" not in key}
    # Step 3: Serialise the *filtered* data for the prompt.  The unfiltered
    # dict was being dumped here before, which made step 2 dead code and
    # sent the long tables to the model anyway.
    json_output = json.dumps(filtered_contract_data, ensure_ascii=False, indent=4)

    # Pydantic model used to validate the model's answer
    class ContractSummary(BaseModel):
        合同编号: Optional[str] = ""
        接收人: Optional[str] = ""
        Recipient: Optional[str] = ""
        接收地: Optional[str] = ""
        Place_of_receipt: Optional[str] = Field("", alias="Place of receipt")
        供应商: Optional[str] = ""
        币种: Optional[str] = ""
        供货日期: Optional[str] = ""

    base_prompt = """You are given a contract in JSON format. Extract the following information:
# Response Format
Return the extracted information as a structured JSON in the exact format shown below (Note: Do not repeat any keys, if unsure leave the value empty):
{
"合同编号": 如果合同编号出现多次,只需填一个,不要重复,优先填写有"-"的合同编号
"接收人": (注意:不是买家必须是接收人,不是一个公司而是一个人)
"Recipient":
"接收地": (注意:不是交货地点是目的港,只写中文,英文写在 place of receipt)
"Place of receipt": (只写英文, 如果接收地/目的港/Port of destination 有英文可填在这里)
"供应商":
"币种": (主要用的货币,填英文缩写。GNF一般是为了方便而转换出来的, 除非只有GNF,GNF一般不是主要币种。)
"供货日期": (如果合同里有写才填,不要自己推理出日期,必须是一个日期,而不是天数)(格式:YYYY-MM-DD)
}
Contract data in JSON format:""" + f"""
{json_output}"""
    messages = [
        {
            "role": "user",
            "content": base_prompt
        }
    ]
    client = OpenAI(
        base_url=base_url,
        api_key=HF_API_KEY,
    )
    # Try up to 3 times with error feedback
    max_retries = 3
    for attempt in range(max_retries):
        try:
            print(f"🔄 LLM attempt {attempt + 1} of {max_retries}")
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=0.1,
            )
            reply = completion.choices[0].message.content
            think_text = re.findall(r"<think>(.*?)</think>", reply, flags=re.DOTALL)
            if think_text:
                print(f"🧠 Thought Process: {think_text}")
            contract_summary = re.sub(r"<think>.*?</think>\s*", "", reply, flags=re.DOTALL)  # Remove think
            contract_summary = re.sub(r"^```json\n|```$", "", contract_summary, flags=re.DOTALL)  # Remove ```
            contract_json = json.loads(contract_summary.strip())
            # Clean 合同编号: drop bracketed content (ASCII or full-width
            # brackets) and anything from the first "/" onwards.  A
            # non-string value is left for validation to reject.
            if (
                isinstance(contract_json, dict)
                and isinstance(contract_json.get("合同编号"), str)
                and contract_json["合同编号"]
            ):
                contract_json["合同编号"] = re.sub(r'[(\uff08].*?[)\uff09]', '', contract_json["合同编号"]).strip()
                contract_json["合同编号"] = re.sub(r'/\s*.*$', '', contract_json["合同编号"]).strip()
            validated_data = ContractSummary.model_validate(contract_json)
            # Success! Return validated data
            validated_json = json.dumps(validated_data.model_dump(by_alias=True), ensure_ascii=False, indent=4)
            if save_json:
                with open(json_filename, "w", encoding="utf-8") as f:
                    f.write(validated_json)
            print(f"✅ Successfully validated contract summary on attempt {attempt + 1}")
            return json.dumps(validated_json, ensure_ascii=False, indent=4)
        except ValidationError as e:
            error_msg = f"Validation error: {e}"
            print(f"❌ {error_msg}")
        except json.JSONDecodeError as e:
            error_msg = f"JSON decode error: {e}"
            print(f"❌ {error_msg}")
        # Don't retry on the last attempt
        if attempt < max_retries - 1:
            # Add the failed answer and the error to the conversation and retry
            messages.append({
                "role": "assistant",
                "content": completion.choices[0].message.content
            })
            messages.append({
                "role": "user",
                "content": f"Your response had the following error: {error_msg}. Please fix the format and provide a valid JSON response with the required fields."
            })
    # If we get here, all attempts failed - return empty but valid model
    print("⚠️ All attempts failed, returning empty model")
    empty_data = ContractSummary().model_dump(by_alias=True)
    empty_json = json.dumps(empty_data, ensure_ascii=False, indent=4)
    if save_json:
        with open(json_filename, "w", encoding="utf-8") as f:
            f.write(empty_json)
    return json.dumps(empty_json, ensure_ascii=False, indent=4)
def find_price_list_table(extracted_data, min_matches=3):
    """Return the last long table whose headers look like a price list.

    A table qualifies when at least ``min_matches`` of its column headers
    (taken from its first row) fuzzy-match one of the price-list keywords
    with a partial ratio of 70 or more.

    Args:
        extracted_data: Mapping of table keys to parsed tables.  Only
            non-empty list values whose key contains ``"long_table"`` are
            considered.
        min_matches: Minimum number of matching headers.

    Returns:
        ``(table, key)`` of the last qualifying table in document order, or
        ``(None, None)`` when none qualifies.
    """
    price_keywords = [
        "名称", "name", "规格", "specification", "型号", "model", "所属机型", "applicable models",
        "单位", "unit", "数量", "quantity", "单价", "unit price", "总价", "amount",
        "几郎单价", "unit price(gnf)", "几郎总价", "amount(gnf)", "备注", "remarks", "计划来源", "plan no",
        "货描", "commodity",
    ]

    def _document_order(item):
        # Keys end in the table's running number.  Sort on that number so
        # "long_table_10" comes after "long_table_2"; a plain string sort
        # would put it before.
        trailing_number = re.search(r'(\d+)$', item[0])
        number = int(trailing_number.group(1)) if trailing_number else float('inf')
        return (number, item[0])

    long_tables = sorted(
        (
            (key, table) for key, table in extracted_data.items()
            if "long_table" in key and isinstance(table, list) and table
        ),
        key=_document_order,
    )

    last_price_list_table = None
    last_price_list_key = None
    for key, table in long_tables:
        match_count = 0
        for header in table[0].keys():
            header_lower = header.lower()
            # One matching keyword is enough to count this header.
            if any(fuzz.partial_ratio(header_lower, keyword.lower()) >= 70 for keyword in price_keywords):
                match_count += 1
        if match_count >= min_matches:
            # Keep overwriting so the last qualifying table wins.
            last_price_list_table = table
            last_price_list_key = key
    return last_price_list_table, last_price_list_key
def extract_price_list(price_list, save_json=False, json_name="price_list.json", fuzzy=False):
"""
Extracts structured price list by first using hardcoded mapping, then falling back to AI if needed.
Set fuzzy=False to use direct string matching for mapping.
"""
# If price_list is empty, return an empty list
if not price_list:
return []
# Convert price_list to a list if it's a dict
if isinstance(price_list, dict):
# Check if the dict has any items
if len(price_list) == 0:
return []
# Convert to list if it's just a single entry dict
price_list = [price_list]
# Extract a sample row for header mapping
sample_row = price_list[0] if price_list else {}
# If there are no headers, return empty list
if not sample_row:
return []
# Get the headers directly from the sample row
extracted_headers = list(sample_row.keys())
# Clean double spaces in headers to facilitate matching
def clean_header_spaces(headers):
"""
Cleans headers for consistent matching by:
1. Normalizing multiple spaces to single space
2. Ensuring exactly one space between Chinese and English
"""
if not headers:
return headers
cleaned_headers = []
for header in headers:
if not header:
cleaned_headers.append(header)
continue
# Normalize multiple spaces to single space
header = re.sub(r'\s+', ' ', header)
# Ensure exactly one space between Chinese and English
header = re.sub(r'([\u4e00-\u9fff])\s*([a-zA-Z])', r'\1 \2', header)
header = re.sub(r'([a-zA-Z])\s*([\u4e00-\u9fff])', r'\1 \2', header)
# Final cleanup of any remaining multiple spaces
header = re.sub(r'\s+', ' ', header)
cleaned_headers.append(header.strip())
return cleaned_headers
# Define our target fields from the Pydantic model
target_fields = [
"序号", "名称", "名称(英文)", "品牌", "规格型号", "所属机型",
"数量", "单位", "单价", "总价", "几郎单价", "几郎总价",
"备注", "计划来源"
]
# Hardcoded mapping dictionary
hardcoded_mapping = {
# 序号 mappings
"序号": ["序号 no.", "序号 no", "no.", "no", "序号no.", "序号no", "序号 item", "序号item", "序号", "序号 no.:"],
# 名称 mappings
"名称": ["名称 name", "名称name", "name", "名称name of materials", "名称name of materials and equipment", "名称 name of materials", "名称 name of materials and equipment", "名称", "产品名称 product name", "货描", "commodity",],
# 名称(英文) mappings
"名称(英文)": ["名称 name", "名称name", "name", "名称name of materials", "名称name of materials and equipment", "名称 name of materials", "名称 name of materials and equipment", "名称", "产品名称 product name"],
# 品牌 mappings
"品牌": ["品牌 brand", "品牌brand", "brand", "品牌 brand", "品牌brand", "品牌"],
# 规格型号 mappings
"规格型号": ["规格型号 specification", "规格型号specification", "规格 specification", "规格specification",
"specification", "规格型号specification and model", "型号model", "型号 model", "规格型号 specification and model", "规格型号"],
# 所属机型 mappings
"所属机型": ["所属机型 applicable models", "所属机型applicable models", "applicable models", "所属机型"],
# 数量 mappings
"数量": ["数量 quantity", "数量quantity", "quantity", "qty", "数量qty", "数量"],
# 单位 mappings
"单位": ["单位 unit", "单位unit", "unit", "单位"],
# 单价 mappings
"单价": ["单价 unit price (cny)", "单价unit price (cny)", "单价(元)Unit Price (CNY)", "unit price (cny)", "单价unit price", "单价 unit price", "单价 unit price(cny)",
"单价(元)", "单价(cny)", "单价 unit price (cny)", "单价(欧元) unit price(eur)", "单价", "单价(元) unit price(cny)", "单价(元)unit price(cny)", "单价(欧元) unit price(eur)",
"价格 price", "价格price", "价格",
"美元单价"],
# 总价 mappings
"总价": ["总价 total amount (cny)", "总价total amount (cny)", "total amount (cny)", "总价total amount", "总价 total amount",
"总价(元)", "总额(元)", "总价 total amount (cny)", "总价(欧元) amount(eur)", "总价", "总价(元)amount (cny)", "总价(元)amount(cny)",
"总额 total amount (cny)", "总额", "总额 total amount","美元总价"],
# 几郎单价 mappings
"几郎单价": ["几郎单价 unit price (gnf)", "几郎单价unit price (gnf)", "unit price (gnf)", "几郎单价unit price", "几郎单价 unit price",
"几郎单价(元)", "单价(几郎)","单价 unit price (gnf)", "几郎单价 unit price (gnf)", "几郎单价", "单价 unit price(几郎)(gnf)", "单价(元)unit price(cny)", "几郎单价 unit price(gnf)"],
# 几郎总价 mappings
"几郎总价": ["几郎总价 total amount (gnf)", "几郎总价total amount (gnf)", "total amount (gnf)", "几郎总价total amount", "几郎总价 total amount",
"几郎总价(元)", "总额(几郎)", "几郎总价 total amount (gnf)", "几郎总价", "总额 total amount(几郎)(gnf)", "总价(元)amount(cny)", "几郎总价 amount(gnf)","总额 total amount (gnf)"],
# 备注 mappings
"备注": ["备注 remarks", "备注remarks", "remarks", "备注 notes", "备注notes", "note", "备注"],
# 计划来源 mappings
"计划来源": ["计划来源 plan no.", "计划来源plan no.", "计划来源(唛头信息)",
"计划来源 planned source", "计划来源planned source", "planned source", "计划来源","计划号 plan no."]
}
# Clean the extracted headers first
cleaned_extracted_headers = clean_header_spaces(extracted_headers)
# Clean all possible headers in the hardcoded mapping
cleaned_hardcoded_mapping = {
std_field: [clean_header_spaces([h])[0] for h in possible_headers]
for std_field, possible_headers in hardcoded_mapping.items()
}
# Fuzzy matching function
def fuzzy_match_header(header, possible_headers, threshold=70):
if not possible_headers:
return None, 0
best_match = process.extractOne(header, possible_headers, scorer=fuzz.ratio)
if best_match and best_match[1] >= threshold:
return best_match[0], best_match[1]
else:
return None, 0
# Try to map headers using hardcoded mapping (fuzzy or direct)
standard_field_mapping = {}
unmapped_headers = []
if fuzzy:
print("\n🔍 Fuzzy Hardcoded Mapping Results:")
else:
print("\n🔍 Direct Hardcoded Mapping Results:")
print("-" * 50)
for header in cleaned_extracted_headers:
header_mapped = False
if fuzzy:
best_match_score = 0
best_match_field = None
best_match_header = None
for std_field, possible_headers in cleaned_hardcoded_mapping.items():
if std_field in standard_field_mapping:
continue
matched_header, score = fuzzy_match_header(header, possible_headers, threshold=70)
if matched_header and score > best_match_score:
best_match_score = score
best_match_field = std_field
best_match_header = matched_header
if best_match_field and best_match_score >= 70:
standard_field_mapping[best_match_field] = header
header_mapped = True
print(f"✅ {best_match_field} -> {header} (score: {best_match_score})")
else:
for std_field, possible_headers in cleaned_hardcoded_mapping.items():
if std_field in standard_field_mapping:
continue
if header in possible_headers:
standard_field_mapping[std_field] = header
header_mapped = True
print(f"✅ {std_field} -> {header}")
break
if not header_mapped:
unmapped_headers.append(header)
print(f"❌ No match found for: {header}")
print("-" * 50)
# If we have unmapped headers, fall back to AI mapping
if unmapped_headers:
print(f"⚠️ Some headers could not be mapped using hardcoded mapping: {unmapped_headers}")
print("🔄 Falling back to AI mapping...")
# Get the list of standard fields that haven't been mapped yet
unmapped_standard_fields = [field for field in target_fields if field not in standard_field_mapping]
# Use AI to map remaining headers
base_prompt = f"""
You are playing a matching game. Match each and every standard fields to the exact column headers within "" separated by ,.
You must match all the given column headers to the standard fields to you best ability.
USE THE EXACT HEADER BELOW INCLUDING BOTH CHINESE AND ENGLISH AND THE EXACT SPACING.
The standard fields that need mapping are:
{json.dumps(unmapped_standard_fields, ensure_ascii=False)}
You are given column headers below: (YOU MUST USE THE EXACT HEADER BELOW INCLUDING BOTH CHINESE AND ENGLISH AND THE EXACT SPACING)
{json.dumps(unmapped_headers, ensure_ascii=False)}
ENSURE ALL STANDARD FIELDS ARE MAPPED TO THE EXACT COLUMN HEADER INCLUDING BOTH CHINESE AND ENGLISH AND THE EXACT SPACING.
Return only a JSON mapping in this format WITHOUT any explanations:
```json
{{
"standard_field_1": "column_header_1",
"standard_field_2": "column_header_2",
...
}}
```
Common mistakes to note:
Do not force map 名称(英文) to 单价
"""
messages = [{"role": "user", "content": base_prompt}]
client = OpenAI(
base_url=base_url,
api_key=HF_API_KEY,
)
# Add retry logic for AI mapping
max_retries = 3
for attempt in range(max_retries):
try:
print(f"🔄 Sending prompt to LLM (attempt {attempt + 1} of {max_retries})")
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0.1,
)
raw_mapping = response.choices[0].message.content
think_text = re.findall(r"<think>(.*?)</think>", response.choices[0].message.content, flags=re.DOTALL)
if think_text:
print(f"🧠 Thought Process: {think_text}")
raw_mapping = re.sub(r"<think>.*?</think>\s*", "", raw_mapping, flags=re.DOTALL) # Remove think
# Remove any backticks or json tags
raw_mapping = re.sub(r"```json|```", "", raw_mapping)
# Parse the AI mapping and merge with hardcoded mapping
ai_mapping = json.loads(raw_mapping.strip())
standard_field_mapping.update(ai_mapping)
# Check if all standard fields are mapped
still_unmapped = [field for field in target_fields if field not in standard_field_mapping]
if still_unmapped:
print(f"⚠️ Some standard fields are still unmapped: {still_unmapped}")
if attempt < max_retries - 1:
# Add feedback to the prompt for the next attempt
messages.append({
"role": "assistant",
"content": response.choices[0].message.content
})
messages.append({
"role": "user",
"content": f"The following standard fields are still unmapped: {still_unmapped}. Please try to map these fields using the available headers: {unmapped_headers}"
})
continue
else:
print(f"✅ Successfully mapped all fields using AI")
print("\n📊 AI Mapping Results:")
print("-------------------")
for std_field, mapped_header in ai_mapping.items():
print(f"{std_field} -> {mapped_header}")
print("-------------------")
break
except Exception as e:
error_msg = f"Error in AI mapping attempt {attempt + 1}: {e}"
print(f"❌ {error_msg}")
if attempt < max_retries - 1:
messages.append({
"role": "assistant",
"content": response.choices[0].message.content
})
messages.append({
"role": "user",
"content": f"Your response had the following error: {error_msg}. Please fix your mapping and try again."
})
else:
print(f"⚠️ All AI mapping attempts failed, proceeding with partial mapping")
# After all mapping is done, print the final mapping and unmapped columns
print("\n📊 Final Field Mapping:")
print("-" * 50)
# Print all standard fields, showing mapping if exists or blank if not
for field in target_fields:
mapped_header = standard_field_mapping.get(field, "")
print(f"{field} -> {mapped_header}")
print("-" * 50)
# Check for unmapped standard fields
unmapped_standard = [field for field in target_fields if field not in standard_field_mapping]
if unmapped_standard:
print("\n⚠️ Unmapped Standard Fields:")
print("-" * 50)
for field in unmapped_standard:
print(f"- {field}")
print("-" * 50)
# Check for unmapped extracted headers
mapped_headers = set(standard_field_mapping.values())
unmapped_headers = [header for header in extracted_headers if header not in mapped_headers]
if unmapped_headers:
print("\n⚠️ Unmapped Extracted Headers:")
print("-" * 50)
for header in unmapped_headers:
print(f"- {header}")
print("-" * 50)
# Function to separate Chinese and English text
def separate_chinese_english(text):
    """Split a bilingual name into its Chinese part and its trailing English part.

    The split point is placed right after the last CJK character
    (U+4E00..U+9FFF). Closing brackets/quotes that directly follow that
    character stay with the Chinese part, and punctuation that ends up at the
    start of the English part is moved back to the Chinese part.

    Args:
        text: The raw cell value. Anything that is not a non-empty ``str``
            yields ``("", "")``.

    Returns:
        A ``(chinese_part, english_part)`` tuple of stripped strings. When
        ``text`` has no CJK character the whole text is returned as the
        English part. When the text after the split point contains no ASCII
        letters, the English part is ``""``.
    """
    if not text or not isinstance(text, str):
        return "", ""

    # Characters that may close a bracketed/quoted Chinese phrase. Written
    # with \u escapes so that full-width characters survive any tooling that
    # normalises the source file. Opening brackets are deliberately absent:
    # "(" after the last Chinese character starts the English part.
    closing_brackets = frozenset(
        ")]\""                              # ASCII ) ] and straight double quote
        "\uff09\uff3d"                      # full-width ) ]
        "\u3011\u300d\u300f\u300b\u3009"    # CJK closing brackets
        "\u201d\u2019"                      # right curly double / single quote
    )
    # Punctuation that belongs to the Chinese sentence when it leads the
    # English part: ideographic comma/full stop, full-width , ; : ! ? and
    # their ASCII counterparts.
    leading_punctuation = r'^[\u3001\u3002\uff0c\uff1b\uff1a\uff01\uff1f,;:!?]+'

    last_chinese_pos = -1
    for index, char in enumerate(text):
        if '\u4e00' <= char <= '\u9fff':
            last_chinese_pos = index

    if last_chinese_pos < 0:
        # No Chinese characters: everything is English.
        return "", text.strip()

    # Start right after the last Chinese character, then absorb the run of
    # closing brackets/quotes that immediately follows it.
    split_pos = last_chinese_pos + 1
    while split_pos < len(text) and text[split_pos] in closing_brackets:
        split_pos += 1

    chinese_part = text[:split_pos].strip()
    english_part = text[split_pos:].strip()

    if english_part:
        punct_match = re.match(leading_punctuation, english_part)
        if punct_match:
            chinese_part += punct_match.group()
            english_part = english_part[punct_match.end():].strip()

        # Digits/symbols alone are not an English name.
        if not re.search(r'[a-zA-Z]', english_part):
            english_part = ""

    return chinese_part, english_part
# Process the data based on the final mapping
transformed_data = []
for row in price_list:
new_row = {field: "" for field in target_fields} # Initialize with empty strings
other_fields = {}
# Step 1: Handle name fields first - look for any field with "名称" or "name"
for header, value in row.items():
# Skip if header is None
if header is None:
continue
# Clean the header for comparison
cleaned_header = re.sub(r'\s+', ' ', str(header)).strip()
header_lower = cleaned_header.lower()
if ("名称" in header_lower or "name" in header_lower) and value:
# If field contains both Chinese and English, separate them
if re.search(r'[\u4e00-\u9fff]', str(value)) and re.search(r'[a-zA-Z]', str(value)):
chinese, english = separate_chinese_english(str(value))
if chinese:
new_row["名称"] = chinese
if english:
new_row["名称(英文)"] = english
# print(f"Separated: '{value}' → Chinese: '{chinese}', English: '{english}'")
else:
# Just set the name directly
new_row["名称"] = str(value)
break # Stop after finding first name field
# Step 2: Fill in all other fields using standard mapping
for header, value in row.items():
# Skip if header is None
if header is None:
continue
# Skip empty values
if not value:
continue
# Clean the header for comparison
cleaned_header = re.sub(r'\s+', ' ', str(header)).strip()
# Check if this maps to a standard field using fuzzy matching
matched_field = None
best_match_score = 0
for std_field, mapped_header in standard_field_mapping.items():
# Skip if mapped_header is None
if mapped_header is None:
continue
# Use fuzzy matching for more flexible comparison
score = fuzz.ratio(cleaned_header.lower().strip(), mapped_header.lower().strip())
if score > best_match_score and score >= 80: # High threshold for data processing
best_match_score = score
matched_field = std_field
# If we found a mapping, use it (but don't overwrite name fields)
if matched_field:
if matched_field not in ["名称", "名称(英文)"] or not new_row[matched_field]:
new_row[matched_field] = str(value)
# If no mapping found, add to other_fields
else:
# Skip name fields we already processed
header_lower = cleaned_header.lower()
if not ("名称" in header_lower or "name" in header_lower):
other_fields[header] = str(value)
# Add remaining fields to "其他"
if other_fields:
new_row["其他"] = other_fields
else:
new_row["其他"] = {}
# Convert field names for validation
if "名称(英文)" in new_row:
new_row["名称(英文)"] = new_row.pop("名称(英文)")
transformed_data.append(new_row)
# Save to file if requested
if save_json and transformed_data:
# Handle edge cases before saving
transformed_data = handle_edge_cases(transformed_data)
with open(json_name, "w", encoding="utf-8") as f:
json.dump(transformed_data, f, ensure_ascii=False, indent=4)
print(f"✅ Saved to {json_name}")
# Handle edge cases (including duplicate column merging) before returning
transformed_data = handle_edge_cases(transformed_data)
return transformed_data
def json_to_excel(contract_summary, json_data, excel_path):
    """Write the contract summary and the price list to an Excel workbook.

    Args:
        contract_summary: Doubly JSON-encoded string (a JSON string whose
            decoded value is itself a JSON document); decoded twice here.
        json_data: Mapping of table name -> table data, or a JSON string
            encoding such a mapping.
        excel_path: Destination path of the workbook.

    The workbook gets two sheets: "Contract Summary" (one row) and
    "Price List" (the last table whose key contains "long_table" but not
    "summary"; an empty sheet when there is none).
    """
    # The summary arrives encoded twice, so it is decoded twice.
    summary_record = json.loads(json.loads(contract_summary))
    summary_frame = pd.DataFrame([summary_record])

    # Accept either an already-parsed mapping or its JSON text.
    tables = json.loads(json_data) if isinstance(json_data, str) else json_data

    candidate_frames = []
    for table_name, table_rows in tables.items():
        if "long_table" not in table_name or "summary" in table_name:
            continue
        candidate_frames.append(pd.DataFrame(table_rows))

    # Only the last matching table is exported.
    price_frame = candidate_frames[-1] if candidate_frames else pd.DataFrame()

    with pd.ExcelWriter(excel_path) as workbook:
        summary_frame.to_excel(workbook, sheet_name="Contract Summary", index=False)
        price_frame.to_excel(workbook, sheet_name="Price List", index=False)
#--- Handle Edge Cases ------------------------------
# Case-insensitive substrings that mark a "converted weight" column header.
# Compiled once instead of rebuilding a pattern list for every row.
_WEIGHT_HEADER_RE = re.compile(
    r"换算重量|重量换算|converted weight|weight conversion",
    re.IGNORECASE,
)
# Text inside the first bracket pair of a header; accepts ASCII "(...)" and
# full-width parentheses (written as \u escapes so they survive normalisation).
_HEADER_UNIT_RE = re.compile(r"[(\uff08]([^)\uff09]+)[)\uff09]")
# Unit used when the header carries no usable unit in brackets.
_DEFAULT_WEIGHT_UNIT = "吨"


def handle_weight_conversion_edge_case(transformed_data):
    """Promote a converted-weight column from '其他' to quantity and unit.

    For every row whose '其他' value is a dict, the first key that looks like
    a converted-weight header is located. If its value parses as a positive,
    finite number, the row's '数量' is replaced by that number (as ``str`` of
    the float), '单位' is replaced by the unit found in the header's brackets
    (default '吨'), and the key is removed from '其他'.

    Args:
        transformed_data: Iterable of row dicts; rows are modified in place.

    Returns:
        The same ``transformed_data`` object.
    """
    for row in transformed_data:
        other_fields = row.get("其他")
        if not isinstance(other_fields, dict):
            continue

        # First header in '其他' that names a converted weight, if any.
        weight_key = next(
            (key for key in other_fields if _WEIGHT_HEADER_RE.search(str(key))),
            None,
        )
        if weight_key is None:
            continue

        raw_weight = other_fields[weight_key]
        if not raw_weight:
            continue

        try:
            weight_value = float(raw_weight)
        except (ValueError, TypeError):
            # Not a number: leave the row untouched.
            print(f"Warning: Invalid weight value '{raw_weight}' in row")
            continue

        # Rejects zero, negatives, NaN and infinity.
        if not 0 < weight_value < float("inf"):
            continue

        original_quantity = row.get("数量", "")
        original_unit = row.get("单位", "")

        unit = _DEFAULT_WEIGHT_UNIT
        bracket_match = _HEADER_UNIT_RE.search(str(weight_key))
        if bracket_match:
            # Keep only letters / CJK characters; fall back to the default
            # when nothing usable remains (e.g. a header like "...(1)").
            cleaned_unit = re.sub(r'[^a-zA-Z\u4e00-\u9fff]', '', bracket_match.group(1))
            if cleaned_unit:
                unit = cleaned_unit

        row["数量"] = str(weight_value)
        row["单位"] = unit
        print(f"Converted weight: {weight_value}{unit} (original: {original_quantity} {original_unit})")

        # The value now lives in '数量'/'单位'; drop it from the leftovers.
        del other_fields[weight_key]

    return transformed_data
def handle_edge_cases(transformed_data):
    """Run every edge-case fixer over the transformed rows, in order.

    Steps applied:
    1. handle_weight_conversion_edge_case - converted weight found in '其他'
    2. merge_duplicate_columns - removal of duplicate columns

    Returns whatever the last step returns.
    """
    edge_case_steps = (
        handle_weight_conversion_edge_case,
        merge_duplicate_columns,
    )
    for apply_step in edge_case_steps:
        transformed_data = apply_step(transformed_data)
    return transformed_data
def merge_duplicate_columns(transformed_data):
    """Drop columns whose name ends in ``_<digits>`` (e.g. ``_2``, ``_3``).

    Such columns are treated as duplicates produced by header cells spanning
    several columns. Despite the name nothing is merged: the suffixed columns
    are simply deleted from every row, in place.

    Returns the same ``transformed_data`` object.
    """
    if not transformed_data:
        return transformed_data

    # Collect every suffixed column name seen in any row.
    suffixed_names = {
        name
        for record in transformed_data
        for name in record
        if re.match(r'^.+_\d+$', name)
    }
    if not suffixed_names:
        return transformed_data

    print(f"🗑️ Removing duplicate columns: {sorted(suffixed_names)}")
    for record in transformed_data:
        for name in suffixed_names:
            record.pop(name, None)

    return transformed_data
#--- Extract PO ------------------------------
def extract_po(docx_path):
    """Extract the contract summary and the price list from one .docx contract.

    Steps: read the DOCX into memory, extract its document XML, convert the
    XML tables to JSON, locate the price list table, have the contract summary
    built from the remaining data plus the first price-list row, then process
    the full price list.

    Args:
        docx_path: Path (``str`` or ``os.PathLike``) of an existing file whose
            name ends in ".docx" (case-insensitive).

    Returns:
        dict with "contract_summary" (the doubly JSON-encoded result of
        ``extract_contract_summary``, decoded) and "price_list" (the result
        of ``extract_price_list``).

    Raises:
        ValueError: If ``docx_path`` is not a path, does not exist, does not
            end in ".docx", or is not a readable DOCX archive.
    """
    # A missing upload arrives as None; report it like any other invalid
    # file instead of failing with a TypeError inside os.path.
    if not isinstance(docx_path, (str, os.PathLike)):
        raise ValueError(f"Invalid file: {docx_path}")
    docx_path = os.fspath(docx_path)
    if not os.path.exists(docx_path) or not docx_path.lower().endswith(".docx"):
        raise ValueError(f"Invalid file: {docx_path}")

    base_name = os.path.splitext(os.path.basename(docx_path))[0]

    # Read the .docx file as bytes
    with open(docx_path, "rb") as f:
        docx_bytes = BytesIO(f.read())

    # The context manager closes the in-memory buffer on every exit path.
    with docx_bytes:
        # Step 1: Extract XML content from DOCX
        print("Extracting Docs data to XML...")
        xml_filename = base_name + "_document.xml"
        try:
            xml_file = extract_docx_as_xml(docx_bytes, save_xml=False, xml_filename=xml_filename)
            get_namespace(ET.fromstring(xml_file))
        except (zipfile.BadZipFile, KeyError) as exc:
            # Not a zip archive, or the archive lacks the expected member.
            raise ValueError(f"Invalid file: {docx_path}") from exc

        # Step 2: Extract tables from DOCX and save JSON
        print("Extracting XML data to JSON...")
        json_filename = base_name + "_extracted_data.json"
        extracted_data = xml_to_json(xml_file, save_json=False, json_filename=json_filename)

        # Find and rename the price list table before contract summary processing
        print("Identifying Price List table...")
        extracted_data_dict = json.loads(extracted_data)
        price_list_table, price_list_key = find_price_list_table(extracted_data_dict)

        if price_list_table:
            # Remove only the specific table that was used to create the price list
            if price_list_key:
                extracted_data_dict.pop(price_list_key, None)
            extracted_data_dict["price_list"] = price_list_table
        else:
            print("⚠️ No suitable price list table found!")
            extracted_data_dict["price_list"] = []
        extracted_data = json.dumps(extracted_data_dict, ensure_ascii=False, indent=4)

        # Independent copy (via JSON round-trip) that keeps only the first
        # price-list row, used as input for the contract summary.
        contract_summary_dict = json.loads(extracted_data)
        summary_price_list = contract_summary_dict.get("price_list")
        if summary_price_list:
            contract_summary_dict["price_list"] = [summary_price_list[0]]
        contract_summary_data = json.dumps(contract_summary_dict, ensure_ascii=False, indent=4)
        print(f"✅ Contract Summary Data: {contract_summary_data}")

        # Step 3: Process JSON with OpenAI to get structured output
        print("Processing Contract Summary data with AI...")
        contract_summary_filename = base_name + "_contract_summary.json"
        contract_summary = extract_contract_summary(
            contract_summary_data, save_json=False, json_filename=contract_summary_filename
        )

        # Process the price list
        print("Processing Price List data with AI...")
        price_list_filename = os.path.join(os.path.dirname(docx_path), base_name + "_price_list.json")
        price_list = extract_price_list(
            price_list_table, save_json=False, json_name=price_list_filename, fuzzy=True
        )

        # Step 4: Combine contract summary and long table data into a single JSON object
        print("Combining AI Generated JSON with Extracted Data...")
        combined_data = {
            # The summary is JSON-encoded twice, hence the double decode.
            "contract_summary": json.loads(json.loads(contract_summary)),
            "price_list": price_list
        }

        return combined_data
# Example Usage
# print(extract_po("test-contracts\GN-SMB268202501-042WJ SMB268波纹管采购合同-东营顺航.docx"))
# print(extract_po(r"UAT Contracts\20250703\GN-WAPJS202405-297HG 1200R20轮胎采购合同-威海君乐-法务审批0515.docx"))
# print(extract_price_list([{'序号 No.': '1', '名称 Name': 'PE波纹管(双壁波纹管) PE corrugated pipe (double wall corrugated pipe)', '规格 Specification': '内径600mm,6米/根,SN8 Inner diameter 600mm, 6 meters per piece, SN8', '单位 Unit': '米m', '数量 Quantity': '180', '单价(元) Unit Price (CNY)': '106.00', '总额(元) Total Amount (CNY)': '1080.00', '几郎单价(元) Unit Price (GNF)': '16.21', '几郎总额(元) Total Amount (GNF)': '22118.38', '品牌 Brand': '鹏洲PZ', '计划来源 Planned Source': 'SMB268-GNHY-0021-WJ-20250108'}]))
# Gradio Interface ------------------------------
import gradio as gr
from gradio.themes.base import Base

# Web UI: a single .docx upload goes to extract_po and the returned dict is
# rendered as JSON. fn / inputs / outputs are passed positionally.
interface = gr.Interface(
    extract_po,
    gr.File(label="买卖合同 (.docx)"),
    gr.Json(label="提取结果"),
    title="PO Extractor 买卖合同数据提取",
    flagging_mode="never",
    theme=Base(),
)

# show_error=True is passed so that exceptions raised by extract_po are
# reported in the UI.
interface.launch(show_error=True)