# Sandy2636
# Add application file
# 1705f76
import gradio as gr
import base64
import requests
import json
import re
import os
import uuid
from datetime import datetime
# --- Configuration ---
# The OpenRouter key is read from the environment (on Hugging Face Spaces, add a
# Secret named OPENROUTER_API_KEY). Never hard-code it here: a key committed to
# source control is public and must be revoked/rotated at the provider.
OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY", "").strip()
IMAGE_MODEL = "opengvlab/internvl3-14b:free"  # Free-tier vision model used for OCR
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# --- Global State ---
# Module-level state, reassigned to empty containers at the start of every
# process_uploaded_files() run. Because it is module-level, all sessions of the
# app share it; a multi-user deployment would need per-session state or a
# proper backend store.
processed_files_data = []  # One dict per uploaded file (id, name, status, OCR JSON, entities, person key)
person_profiles = {}  # person_key -> profile dict (names, DOBs, passport numbers, doc_ids, display name)
# --- Helper Functions ---
def extract_json_from_text(text):
    """
    Extract a JSON object from a model response.

    Strategies are tried in order:
      1. A ```json ... ``` fenced code block.
      2. The whole (stripped) text, optionally wrapped in single backticks.
      3. The substring from the first '{' to the last '}' (tolerates prose
         around the object).

    Args:
        text: Raw text returned by the model.

    Returns:
        The parsed JSON object as a dict. On failure, a dict with an "error"
        key (plus "original_text" when parsing was attempted). A JSON value
        that parses but is not an object (list, string, number, null) is also
        reported as an error, because every caller treats the result as a dict.
    """
    if not text:
        return {"error": "Empty text provided for JSON extraction."}

    # Prefer an explicit ```json fenced block when the model emitted one.
    match_block = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
    if match_block:
        json_str = match_block.group(1)
    else:
        json_str = text.strip()
        # Tolerate inline-code wrapping such as `{...}`.
        if json_str.startswith("`") and json_str.endswith("`"):
            json_str = json_str[1:-1]

    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as e:
        # Fallback: the model may have added explanatory text around the object.
        first_brace = json_str.find("{")
        last_brace = json_str.rfind("}")
        if first_brace == -1 or last_brace <= first_brace:
            return {"error": f"Invalid JSON structure: {str(e)}", "original_text": text}
        try:
            parsed = json.loads(json_str[first_brace:last_brace + 1])
        except json.JSONDecodeError as e2:
            return {"error": f"Invalid JSON structure after attempting substring: {str(e2)}", "original_text": text}

    # Callers index the result like a dict ("error" in ..., .get(...)).
    if not isinstance(parsed, dict):
        return {"error": f"Expected a JSON object but got {type(parsed).__name__}.", "original_text": text}
    return parsed
def get_ocr_prompt():
    """
    Return the instruction prompt sent alongside each image to the vision model.

    The prompt asks for a single JSON object with the top-level keys
    "document_type_detected", "extracted_fields", "mrz_data" and
    "full_text_ocr"; the reply is parsed by extract_json_from_text() and the
    "extracted_fields" map is read by extract_entities_from_ocr().

    Returns:
        The prompt text (str); identical on every call.
    """
    # Plain string literal, not an f-string: there are no placeholders, and an
    # f-string would break as soon as a literal JSON example with braces is added.
    return """You are an advanced OCR and information extraction AI.
Your task is to meticulously analyze this image and extract all relevant information.
Output Format Instructions:
Provide your response as a SINGLE, VALID JSON OBJECT. Do not include any explanatory text before or after the JSON.
The JSON object should have the following top-level keys:
- "document_type_detected": (string) Your best guess of the specific document type (e.g., "Passport", "National ID Card", "Driver's License", "Visa Sticker", "Hotel Confirmation Voucher", "Bank Statement", "Photo of a person").
- "extracted_fields": (object) A key-value map of all extracted information. Be comprehensive. Examples:
- For passports/IDs: "Surname", "Given Names", "Full Name", "Document Number", "Nationality", "Date of Birth", "Sex", "Place of Birth", "Date of Issue", "Date of Expiry", "Issuing Authority", "Country Code".
- For hotel reservations: "Guest Name", "Hotel Name", "Booking Reference", "Check-in Date", "Check-out Date".
- For bank statements: "Account Holder Name", "Account Number", "Bank Name", "Statement Period", "Ending Balance".
- For photos: "Description" (e.g., "Portrait of a person", "Group photo at a location"), "People Present" (array of strings if multiple).
- "mrz_data": (object or null) If a Machine Readable Zone (MRZ) is present:
- "raw_mrz_lines": (array of strings) Each line of the MRZ.
- "parsed_mrz": (object) Key-value pairs of parsed MRZ fields.
If no MRZ, this field should be null.
- "full_text_ocr": (string) Concatenation of all text found on the document.
Extraction Guidelines:
1. Prioritize accuracy.
2. Extract all visible text. Include "Full Name" by combining given and surnames if possible.
3. For dates, try to use ISO 8601 format (YYYY-MM-DD) if possible, but retain original format if conversion is ambiguous.
Ensure the entire output strictly adheres to the JSON format.
"""
def _guess_image_mime_type(image_filepath):
    """Return an image MIME type for *image_filepath* from its extension, defaulting to image/jpeg."""
    import mimetypes  # local import: only this helper needs it

    ext = os.path.splitext(image_filepath)[1].lower()
    # Explicit table first: depending on Python version / OS mime.types,
    # mimetypes may not know every format (e.g. .webp).
    known_types = {
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
    }
    if ext in known_types:
        return known_types[ext]
    guessed, _encoding = mimetypes.guess_type(image_filepath)
    if guessed and guessed.startswith("image/"):
        return guessed
    return "image/jpeg"


def call_openrouter_ocr(image_filepath):
    """
    Send one image to the OpenRouter chat-completions endpoint and parse the OCR JSON.

    The image is base64-encoded into a data URL and sent together with
    get_ocr_prompt() to IMAGE_MODEL.

    Args:
        image_filepath: Path to the image file on disk.

    Returns:
        The dict produced by extract_json_from_text() from the first choice's
        message content, or a dict with an "error" key when the API key is
        missing, the file cannot be read, the request fails or times out, or
        the response has no "choices". Errors are returned, not raised, so one
        bad file does not stop a batch.
    """
    if not OPENROUTER_API_KEY:
        return {"error": "OpenRouter API Key not configured."}
    try:
        with open(image_filepath, "rb") as f:
            encoded_image = base64.b64encode(f.read()).decode("utf-8")
        mime_type = _guess_image_mime_type(image_filepath)
        data_url = f"data:{mime_type};base64,{encoded_image}"

        payload = {
            "model": IMAGE_MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": get_ocr_prompt()},
                        {"type": "image_url", "image_url": {"url": data_url}},
                    ],
                }
            ],
            "max_tokens": 3500,  # Large budget: the requested JSON can be long
            "temperature": 0.1,  # Low temperature for deterministic extraction
        }
        headers = {
            "Authorization": f"Bearer {OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://huggingface.co/spaces/DoClassifier",  # Optional: update with your Space URL
            "X-Title": "DoClassifier Processor",  # Optional
        }
        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180)  # 3 min timeout
        response.raise_for_status()
        result = response.json()

        if "choices" in result and result["choices"]:
            raw_content = result["choices"][0]["message"]["content"]
            return extract_json_from_text(raw_content)
        return {"error": "No 'choices' in API response from OpenRouter.", "details": result}
    except requests.exceptions.Timeout:
        # Must precede RequestException, of which Timeout is a subclass.
        return {"error": "API request timed out."}
    except requests.exceptions.RequestException as e:
        error_message = f"API Request Error: {str(e)}"
        if hasattr(e, 'response') and e.response is not None:
            error_message += f" Status: {e.response.status_code}, Response: {e.response.text}"
        return {"error": error_message}
    except Exception as e:
        # Best-effort per file (e.g. unreadable file, unexpected response shape).
        return {"error": f"An unexpected error occurred during OCR: {str(e)}"}
def _lookup_field(fields, candidate_keys, transform=str):
    """
    Return the first usable value for any of *candidate_keys* in *fields*.

    Field names are compared case-insensitively. Candidate keys are tried in
    priority order; for each, the first field (in dict order) whose lower-cased
    name matches is used. A falsy raw value, or one whose transformed form is
    falsy, is skipped and the next candidate key is tried. Returns None if
    nothing usable is found.
    """
    lowered = {}
    for field_key, value in fields.items():
        # setdefault keeps the first field seen for each lower-cased name.
        lowered.setdefault(str(field_key).lower(), value)
    for key in candidate_keys:
        value = lowered.get(key)
        if value:
            converted = transform(value)
            if converted:
                return converted
    return None


def _normalize_document_number(value):
    """Remove spaces and upper-case a document/passport number."""
    return str(value).replace(" ", "").upper()


def extract_entities_from_ocr(ocr_json):
    """
    Pull the identity fields used for person matching out of an OCR result.

    Args:
        ocr_json: The dict returned by the OCR step (may be None, empty, or
            lack a dict-valued "extracted_fields").

    Returns:
        A dict with keys "name", "dob", "passport_no" (each str or None) and
        "doc_type" (the model's "document_type_detected", default "Unknown").
        Passport numbers are normalized (spaces removed, upper-cased).
    """
    is_dict = isinstance(ocr_json, dict)
    doc_type = ocr_json.get("document_type_detected", "Unknown") if is_dict else "Unknown"
    fields = ocr_json.get("extracted_fields") if is_dict else None
    if not isinstance(fields, dict):
        return {"name": None, "dob": None, "passport_no": None, "doc_type": doc_type}

    # Candidate field names, in priority order (compared lower-cased).
    name_keys = ("full name", "name", "account holder name", "guest name")
    dob_keys = ("date of birth", "dob")
    passport_keys = ("document number", "passport number")

    return {
        "name": _lookup_field(fields, name_keys),
        "dob": _lookup_field(fields, dob_keys),
        "passport_no": _lookup_field(fields, passport_keys, _normalize_document_number),
        "doc_type": doc_type,
    }
def normalize_name(name):
    """
    Return a comparison key for a person's name.

    Keeps only alphanumeric characters (Unicode-aware via str.isalnum) and
    case-folds the result, so "O'Neil-Smith" and "ONEIL SMITH" map to the same
    key. casefold() is used instead of lower() because it is intended for
    caseless matching (e.g. "Strauß" and "STRAUSS" both become "strauss").

    Args:
        name: The name string; None or "" yields "".

    Returns:
        The normalized key (str).
    """
    if not name:
        return ""
    return "".join(ch for ch in name if ch.isalnum()).casefold()
def _attach_document_to_profile(p_data, doc_id, name, dob, passport_no):
    """Add doc_id to an existing profile and record any identifiers it did not have yet."""
    p_data["doc_ids"].add(doc_id)
    if name:
        p_data.setdefault("names", set()).add(normalize_name(name))
        if not p_data.get("canonical_name"):
            p_data["canonical_name"] = name
            # Profiles created from a passport without a readable name carry a
            # placeholder display name; replace it once a real name is known.
            p_data["display_name"] = name
    if dob:
        p_data.setdefault("dobs", set()).add(dob)
        if not p_data.get("canonical_dob"):
            p_data["canonical_dob"] = dob
    if passport_no:
        p_data.setdefault("passport_numbers", set()).add(passport_no)


def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
    """
    Assign a document to a person profile, creating a new profile if needed.

    Matching order:
      1. Passport/document number already recorded on a profile.
      2. Normalized name + exact DOB string already recorded on a profile.
         Documents carrying a not-yet-seen passport number also go through
         this step, so the result does not depend on upload order (a passport
         uploaded after a bank statement for the same person joins that
         profile). A profile that already holds a different passport number
         is never merged this way.
      3. Otherwise a new profile is created, keyed by passport number,
         name + DOB, name only, or a random "unidentified" key, in that order.
    Name-only documents always start a new profile (no fuzzy matching yet).
    When a document joins an existing profile, its name/DOB/passport number
    are added to that profile so later documents can match on them.

    Args:
        doc_id: Unique id of the document being classified.
        entities: Dict with optional "name", "dob" and "passport_no" entries
            (as produced by extract_entities_from_ocr).
        current_persons_data: person_key -> profile dict; updated in place.

    Returns:
        The person_key the document was assigned to.
    """
    passport_no = entities.get("passport_no")
    name = entities.get("name")
    dob = entities.get("dob")

    # 1. Match by passport number (strongest identifier).
    if passport_no:
        for p_key, p_data in current_persons_data.items():
            if passport_no in p_data.get("passport_numbers", set()):
                _attach_document_to_profile(p_data, doc_id, name, dob, passport_no)
                return p_key

    # 2. Match by normalized name + DOB.
    if name and dob:
        norm_name = normalize_name(name)
        for p_key, p_data in current_persons_data.items():
            if passport_no and p_data.get("passport_numbers"):
                # Step 1 found no match, so this profile's passport(s) differ.
                continue
            if norm_name in p_data.get("names", set()) and dob in p_data.get("dobs", set()):
                _attach_document_to_profile(p_data, doc_id, name, dob, passport_no)
                return p_key

    # 3. No existing profile matched: create a new one.
    if passport_no:
        new_person_key = f"person_{passport_no}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {normalize_name(name)} if name else set(),
            "dobs": {dob} if dob else set(),
            "passport_numbers": {passport_no},
            "doc_ids": {doc_id},
            "display_name": name or f"Person (ID: {passport_no})",
        }
        return new_person_key

    if name and dob:
        norm_name = normalize_name(name)
        new_person_key = f"person_{norm_name}_{dob}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {norm_name},
            "dobs": {dob},
            "passport_numbers": set(),
            "doc_ids": {doc_id},
            "display_name": name,
        }
        return new_person_key

    if name:
        # Name alone is too weak to merge on; this may split one person
        # across several profiles until stronger identifiers appear.
        norm_name = normalize_name(name)
        new_person_key = f"person_{norm_name}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": None,
            "names": {norm_name}, "dobs": set(), "passport_numbers": set(),
            "doc_ids": {doc_id}, "display_name": name,
        }
        return new_person_key

    # 4. No identifiers at all: give the document its own placeholder profile.
    generic_person_key = f"unidentified_person_{str(uuid.uuid4())[:6]}"
    current_persons_data[generic_person_key] = {
        "canonical_name": "Unknown", "canonical_dob": None,
        "names": set(), "dobs": set(), "passport_numbers": set(),
        "doc_ids": {doc_id}, "display_name": f"Unknown Person ({doc_id[:6]})",
    }
    return generic_person_key
def format_dataframe_data(current_files_data):
    """
    Build the rows displayed in the document-status Dataframe.

    Column order: short doc id (first 8 chars), filename, status, detected
    document type, name, DOB, passport number, assigned person key.
    Missing values (absent, None or "") are shown as "N/A".

    Args:
        current_files_data: List of per-file dicts with at least "doc_id",
            "filename" and "status"; "entities" and "assigned_person_key"
            are None until the file has been processed.

    Returns:
        A list of 8-element row lists.
    """
    def _cell(value):
        # Entity dicts always contain their keys (often with None values), so a
        # .get() default alone would never apply; normalize blanks here.
        return "N/A" if value is None or value == "" else value

    df_rows = []
    for f_data in current_files_data:
        entities = f_data.get("entities") or {}
        df_rows.append([
            f_data["doc_id"][:8],  # Short ID
            f_data["filename"],
            f_data["status"],
            _cell(entities.get("doc_type")),
            _cell(entities.get("name")),
            _cell(entities.get("dob")),
            _cell(entities.get("passport_no")),
            _cell(f_data.get("assigned_person_key")),
        ])
    return df_rows
def format_persons_markdown(current_persons_data, current_files_data):
    """
    Render the person -> documents grouping as Markdown.

    Each profile gets a heading with its display name and key, the canonical
    DOB and passport numbers (sorted) when known, and its documents listed in
    upload order. Document ids with no matching file entry are listed last.

    Args:
        current_persons_data: person_key -> profile dict.
        current_files_data: List of per-file dicts (need "doc_id", "filename";
            "entities" may be None).

    Returns:
        The Markdown string, or "No persons identified yet." if there are no profiles.
    """
    if not current_persons_data:
        return "No persons identified yet."

    # Index files once instead of scanning the whole list for every document.
    files_by_id = {f["doc_id"]: f for f in current_files_data}

    md_parts = ["## Classified Persons & Documents\n"]
    for p_key, p_data in current_persons_data.items():
        display_name = p_data.get('display_name') or p_key
        md_parts.append(f"### Person: {display_name} (Profile Key: {p_key})")
        if p_data.get("canonical_dob"):
            md_parts.append(f"* DOB: {p_data['canonical_dob']}")
        if p_data.get("passport_numbers"):
            # Sets have no stable iteration order; sort for reproducible output.
            md_parts.append(f"* Passport(s): {', '.join(sorted(p_data['passport_numbers']))}")
        md_parts.append("* Documents:")

        doc_ids_for_person = p_data.get("doc_ids") or set()
        if not doc_ids_for_person:
            md_parts.append(" - No documents currently assigned.")
        else:
            # files_by_id preserves upload order (dicts keep insertion order).
            known_ids = [d for d in files_by_id if d in doc_ids_for_person]
            unknown_ids = sorted(d for d in doc_ids_for_person if d not in files_by_id)
            for doc_id in known_ids:
                doc_detail = files_by_id[doc_id]
                # "entities" is None for files that never got past OCR.
                doc_type = (doc_detail.get("entities") or {}).get("doc_type") or "Unknown Type"
                md_parts.append(f" - {doc_detail['filename']} (`{doc_type}`)")
            for doc_id in unknown_ids:
                md_parts.append(f" - Document ID: {doc_id[:8]} (details not found, unexpected)")
        md_parts.append("\n---\n")
    return "\n".join(md_parts)
# --- Main Gradio Processing Function (Generator) ---
def _uploaded_file_path(file_obj):
    """
    Return the filesystem path for one item from the gr.Files input.

    With type="filepath" Gradio passes plain path strings; some Gradio
    versions pass temp-file objects exposing .name instead. Both are accepted.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    return file_obj.name


def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
    """
    Run OCR, entity extraction and person classification over uploaded files.

    Generator used as a Gradio event handler: every yield is a 4-tuple
    (dataframe rows, persons Markdown, OCR JSON string, overall status text)
    matching the outputs wired to the "Process" button. The module-level
    processed_files_data / person_profiles are reset at the start of each run.

    Args:
        files_list: Items from the gr.Files input (paths or file objects).
        progress: Gradio progress tracker (injected by Gradio).
    """
    global processed_files_data, person_profiles  # Reset global state for each run
    processed_files_data = []
    person_profiles = {}

    if not OPENROUTER_API_KEY:
        yield (
            [["N/A", "ERROR", "OpenRouter API Key not configured.", "N/A", "N/A", "N/A", "N/A", "N/A"]],
            "Error: OpenRouter API Key not configured. Please set it in Space Secrets.",
            "{}", "API Key Missing. Processing halted."
        )
        return
    if not files_list:
        yield ([], "No files uploaded.", "{}", "Upload files to begin.")
        return

    for file_obj in files_list:
        file_path = _uploaded_file_path(file_obj)
        processed_files_data.append({
            "doc_id": str(uuid.uuid4()),
            "filename": os.path.basename(file_path),
            "filepath": file_path,
            "status": "Queued",
            "ocr_json": None,
            "entities": None,
            "assigned_person_key": None,
        })

    total = len(processed_files_data)
    yield (
        format_dataframe_data(processed_files_data),
        format_persons_markdown(person_profiles, processed_files_data),
        "{}",
        f"Initialized. Found {len(files_list)} files.",
    )

    for i, file_data_item in enumerate(progress.tqdm(processed_files_data, desc="Processing Documents")):
        current_doc_id = file_data_item["doc_id"]
        current_filename = file_data_item["filename"]
        step = f"({i+1}/{total})"

        # 1. OCR processing
        file_data_item["status"] = "OCR in Progress..."
        persons_md = format_persons_markdown(person_profiles, processed_files_data)  # No change yet
        yield (format_dataframe_data(processed_files_data), persons_md, "{}", f"{step} OCR for: {current_filename}")

        ocr_result = call_openrouter_ocr(file_data_item["filepath"])
        if not isinstance(ocr_result, dict):
            # Downstream steps need a dict; treat anything else as an OCR failure.
            ocr_result = {"error": f"Unexpected OCR result type: {type(ocr_result).__name__}", "original_result": ocr_result}
        file_data_item["ocr_json"] = ocr_result  # Store full JSON
        ocr_json_text = json.dumps(ocr_result, indent=2)

        if "error" in ocr_result:
            error_text = str(ocr_result["error"])
            ellipsis = "..." if len(error_text) > 50 else ""  # Only mark real truncation
            file_data_item["status"] = f"OCR Error: {error_text[:50]}{ellipsis}"
            yield (format_dataframe_data(processed_files_data), persons_md, ocr_json_text, f"{step} OCR Error on {current_filename}")
            continue  # Move to next file

        file_data_item["status"] = "OCR Done. Extracting Entities..."
        yield (format_dataframe_data(processed_files_data), persons_md, ocr_json_text, f"{step} OCR Done for {current_filename}")

        # 2. Entity extraction
        file_data_item["entities"] = extract_entities_from_ocr(ocr_result)
        file_data_item["status"] = "Entities Extracted. Classifying..."
        yield (format_dataframe_data(processed_files_data), persons_md, ocr_json_text, f"{step} Entities for {current_filename}")

        # 3. Person classification / linking
        person_key = get_person_id_and_update_profiles(current_doc_id, file_data_item["entities"], person_profiles)
        file_data_item["assigned_person_key"] = person_key
        file_data_item["status"] = "Classified"
        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (format_dataframe_data(processed_files_data), persons_md, ocr_json_text, f"{step} Classified {current_filename} -> {person_key}")

    yield (
        format_dataframe_data(processed_files_data),
        format_persons_markdown(person_profiles, processed_files_data),
        "{}",
        f"All {len(processed_files_data)} documents processed.",
    )
# --- Gradio UI Layout ---
# --- Gradio UI Layout ---
# Components are created inside the layout context managers in on-screen
# order, so the statement order below determines the page layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📄 Intelligent Document Processor & Classifier")
    gr.Markdown(
        "**Upload multiple documents (images of passports, bank statements, hotel reservations, photos, etc.). "
        "The system will perform OCR, attempt to extract key entities, and classify documents by the person they belong to.**\n"
        "Ensure `OPENROUTER_API_KEY` is set as a Secret in your Hugging Face Space."
    )
    # Checked once, when the UI is built: show a warning banner if no key is configured.
    if not OPENROUTER_API_KEY:
        gr.Markdown("<h3 style='color:red;'>⚠️ ERROR: `OPENROUTER_API_KEY` is not set in Space Secrets! OCR will fail.</h3>")
    with gr.Row():
        # NOTE(review): this file's indentation was lost; the Row/Column nesting
        # here is a reconstruction — confirm which components belong in the column.
        with gr.Column(scale=1):
            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath")
            process_button = gr.Button("Process Uploaded Documents", variant="primary")
            overall_status_textbox = gr.Textbox(label="Overall Progress", interactive=False, lines=1)
    gr.Markdown("---")
    gr.Markdown("## Document Processing Details")
    # Column order must match the rows built by format_dataframe_data():
    # short doc id, filename, status, detected type, name, DOB, passport no., person key.
    dataframe_headers = ["Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"]
    document_status_df = gr.Dataframe(
        headers=dataframe_headers,
        datatype=["str"] * len(dataframe_headers),  # All as strings for display simplicity
        label="Individual Document Status & Extracted Entities",
        row_count=(0, "dynamic"),  # Start empty; grows as rows are yielded
        col_count=(len(dataframe_headers), "fixed"),
        wrap=True
    )
    ocr_json_output = gr.Code(label="Selected Document OCR JSON", language="json", interactive=False)
    gr.Markdown("---")
    person_classification_output_md = gr.Markdown("## Classified Persons & Documents\nNo persons identified yet.")

    # Event Handlers
    # process_uploaded_files is a generator; each yielded 4-tuple updates these
    # outputs in order (table rows, persons Markdown, OCR JSON, status text).
    process_button.click(
        fn=process_uploaded_files,
        inputs=[files_input],
        outputs=[
            document_status_df,
            person_classification_output_md,
            ocr_json_output,  # Shows the OCR JSON of the file currently being processed; the select handler below shows a clicked row's JSON
            overall_status_textbox
        ]
    )

    @document_status_df.select(inputs=None, outputs=ocr_json_output, show_progress="hidden")
    def display_selected_ocr(evt: gr.SelectData):
        """
        Show the stored OCR JSON for the table row the user clicked.

        Rows are looked up by position in the module-level processed_files_data,
        which is the same list (and order) used to build the table rows.
        Returns a JSON string for the gr.Code output.
        """
        if evt.index is None or evt.index[0] is None:  # evt.index is (row, col)
            return "{}"  # Nothing selected or invalid selection
        selected_row_index = evt.index[0]
        if selected_row_index < len(processed_files_data):
            selected_doc_data = processed_files_data[selected_row_index]
            if selected_doc_data and selected_doc_data["ocr_json"]:
                return json.dumps(selected_doc_data["ocr_json"], indent=2)
        return "{ \"message\": \"No OCR data found for selected row or selection out of bounds.\" }"
if __name__ == "__main__":
    # Enable Gradio's request queue so the long-running generator handler can
    # stream progress updates, then start the server with debug output.
    queued_app = demo.queue()
    queued_app.launch(debug=True)