import gradio as gr
import base64
import requests
import json
import re
import os
import uuid
from datetime import datetime

# --- Configuration ---
# IMPORTANT: Set OPENROUTER_API_KEY as a Hugging Face Space Secret; it is read from the
# environment here so the key never appears in source code.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
IMAGE_MODEL = "opengvlab/internvl3-14b:free"  # Using the free-tier model as specified
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# --- Global State (module-level for simplicity) ---
# This is reset each time the processing function is called.
# For a multi-user or more robust app, session state or a proper backend DB would be needed.
processed_files_data = []  # Stores dicts for each file's details and status
person_profiles = {}       # Stores dicts for each identified person and their documents
# --- Helper Functions ---
def extract_json_from_text(text):
    """
    Extracts a JSON object from a string, trying common markdown and direct JSON.
    """
    if not text:
        return {"error": "Empty text provided for JSON extraction."}

    # Try to match a ```json ... ``` code block
    match_block = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
    if match_block:
        json_str = match_block.group(1)
    else:
        # If no block, assume the text itself might be JSON or wrapped in single backticks
        text_stripped = text.strip()
        if text_stripped.startswith("`") and text_stripped.endswith("`"):
            json_str = text_stripped[1:-1]
        else:
            json_str = text_stripped  # Assume it's direct JSON

    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        # Fallback: try the substring between the first '{' and the last '}'
        try:
            first_brace = json_str.find('{')
            last_brace = json_str.rfind('}')
            if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
                potential_json_str = json_str[first_brace:last_brace + 1]
                return json.loads(potential_json_str)
            else:
                return {"error": f"Invalid JSON structure: {str(e)}", "original_text": text}
        except json.JSONDecodeError as e2:
            return {"error": f"Invalid JSON structure after attempting substring: {str(e2)}", "original_text": text}
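# Illustrative behaviour (hypothetical inputs, not real model responses):
#   extract_json_from_text('```json {"document_type_detected": "Passport"} ```')
#     -> {"document_type_detected": "Passport"}
#   extract_json_from_text('not json at all')
#     -> {"error": "Invalid JSON structure: ...", "original_text": "not json at all"}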
def get_ocr_prompt():
    return """You are an advanced OCR and information extraction AI.
Your task is to meticulously analyze this image and extract all relevant information.
Output Format Instructions:
Provide your response as a SINGLE, VALID JSON OBJECT. Do not include any explanatory text before or after the JSON.
The JSON object should have the following top-level keys:
- "document_type_detected": (string) Your best guess of the specific document type (e.g., "Passport", "National ID Card", "Driver's License", "Visa Sticker", "Hotel Confirmation Voucher", "Bank Statement", "Photo of a person").
- "extracted_fields": (object) A key-value map of all extracted information. Be comprehensive. Examples:
    - For passports/IDs: "Surname", "Given Names", "Full Name", "Document Number", "Nationality", "Date of Birth", "Sex", "Place of Birth", "Date of Issue", "Date of Expiry", "Issuing Authority", "Country Code".
    - For hotel reservations: "Guest Name", "Hotel Name", "Booking Reference", "Check-in Date", "Check-out Date".
    - For bank statements: "Account Holder Name", "Account Number", "Bank Name", "Statement Period", "Ending Balance".
    - For photos: "Description" (e.g., "Portrait of a person", "Group photo at a location"), "People Present" (array of strings if multiple).
- "mrz_data": (object or null) If a Machine Readable Zone (MRZ) is present:
    - "raw_mrz_lines": (array of strings) Each line of the MRZ.
    - "parsed_mrz": (object) Key-value pairs of parsed MRZ fields.
  If no MRZ, this field should be null.
- "full_text_ocr": (string) Concatenation of all text found on the document.
Extraction Guidelines:
1. Prioritize accuracy.
2. Extract all visible text. Include "Full Name" by combining given names and surname if possible.
3. For dates, try to use ISO 8601 format (YYYY-MM-DD) if possible, but retain the original format if conversion is ambiguous.
Ensure the entire output strictly adheres to the JSON format.
"""
def call_openrouter_ocr(image_filepath):
    if not OPENROUTER_API_KEY:
        return {"error": "OpenRouter API Key not configured."}
    try:
        with open(image_filepath, "rb") as f:
            encoded_image = base64.b64encode(f.read()).decode("utf-8")

        # Basic MIME type guessing, default to JPEG
        mime_type = "image/jpeg"
        if image_filepath.lower().endswith(".png"):
            mime_type = "image/png"
        elif image_filepath.lower().endswith(".webp"):
            mime_type = "image/webp"

        data_url = f"data:{mime_type};base64,{encoded_image}"
        prompt_text = get_ocr_prompt()

        payload = {
            "model": IMAGE_MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {"type": "image_url", "image_url": {"url": data_url}}
                    ]
                }
            ],
            "max_tokens": 3500,  # Increased for detailed JSON
            "temperature": 0.1,
        }
        headers = {
            "Authorization": f"Bearer {OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://huggingface.co/spaces/DoClassifier",  # Optional: update with your Space URL
            "X-Title": "DoClassifier Processor"  # Optional
        }

        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180)  # 3-minute timeout
        response.raise_for_status()
        result = response.json()

        if "choices" in result and result["choices"]:
            raw_content = result["choices"][0]["message"]["content"]
            return extract_json_from_text(raw_content)
        else:
            return {"error": "No 'choices' in API response from OpenRouter.", "details": result}
    except requests.exceptions.Timeout:
        return {"error": "API request timed out."}
    except requests.exceptions.RequestException as e:
        error_message = f"API Request Error: {str(e)}"
        if hasattr(e, 'response') and e.response is not None:
            error_message += f" Status: {e.response.status_code}, Response: {e.response.text}"
        return {"error": error_message}
    except Exception as e:
        return {"error": f"An unexpected error occurred during OCR: {str(e)}"}
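# Contract: call_openrouter_ocr() returns either the model's parsed JSON payload (a dict)
# on success, or a dict containing an "error" key on any failure, so callers can simply
# branch on `"error" in result` without try/except.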
def extract_entities_from_ocr(ocr_json):
    if not ocr_json or "extracted_fields" not in ocr_json or not isinstance(ocr_json.get("extracted_fields"), dict):
        # Guard against ocr_json being None before calling .get() on it.
        doc_type = ocr_json.get("document_type_detected", "Unknown") if isinstance(ocr_json, dict) else "Unknown"
        return {"name": None, "dob": None, "passport_no": None, "doc_type": doc_type}

    fields = ocr_json["extracted_fields"]
    doc_type = ocr_json.get("document_type_detected", "Unknown")

    # Normalize potential field names (case-insensitive search)
    name_keys = ["full name", "name", "account holder name", "guest name"]
    dob_keys = ["date of birth", "dob"]
    passport_keys = ["document number", "passport number"]

    extracted_name = None
    for key in name_keys:
        for field_key, value in fields.items():
            if key == field_key.lower():
                extracted_name = str(value) if value else None
                break
        if extracted_name:
            break

    extracted_dob = None
    for key in dob_keys:
        for field_key, value in fields.items():
            if key == field_key.lower():
                extracted_dob = str(value) if value else None
                break
        if extracted_dob:
            break

    extracted_passport_no = None
    for key in passport_keys:
        for field_key, value in fields.items():
            if key == field_key.lower():
                extracted_passport_no = str(value).replace(" ", "").upper() if value else None  # Normalize
                break
        if extracted_passport_no:
            break

    return {
        "name": extracted_name,
        "dob": extracted_dob,
        "passport_no": extracted_passport_no,
        "doc_type": doc_type
    }
def normalize_name(name):
    if not name:
        return ""
    return "".join(filter(str.isalnum, name)).lower()
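# Illustrative (hypothetical input): normalize_name("John  DOE-Smith") -> "johndoesmith",
# so minor spacing/punctuation differences don't split one person into two profiles.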
def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
    """
    Tries to assign a document to an existing person or creates a new one.
    Returns a person_key.
    Updates current_persons_data in place.
    """
    passport_no = entities.get("passport_no")
    name = entities.get("name")
    dob = entities.get("dob")

    # 1. Match by Passport Number (strongest identifier)
    if passport_no:
        for p_key, p_data in current_persons_data.items():
            if passport_no in p_data.get("passport_numbers", set()):
                p_data["doc_ids"].add(doc_id)
                # Update person profile with potentially new name/dob if current is missing
                if name and not p_data.get("canonical_name"):
                    p_data["canonical_name"] = name
                if dob and not p_data.get("canonical_dob"):
                    p_data["canonical_dob"] = dob
                return p_key
        # New person based on passport number
        new_person_key = f"person_{passport_no}"  # Or more robust ID generation
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {normalize_name(name)} if name else set(),
            "dobs": {dob} if dob else set(),
            "passport_numbers": {passport_no},
            "doc_ids": {doc_id},
            "display_name": name or f"Person (ID: {passport_no})"
        }
        return new_person_key

    # 2. Match by Normalized Name + DOB (if passport not found or not present)
    if name and dob:
        norm_name = normalize_name(name)
        composite_key_nd = f"{norm_name}_{dob}"
        for p_key, p_data in current_persons_data.items():
            # Check if this name and dob combo has been seen for this person
            if norm_name in p_data.get("names", set()) and dob in p_data.get("dobs", set()):
                p_data["doc_ids"].add(doc_id)
                return p_key
        # New person based on name and DOB
        new_person_key = f"person_{composite_key_nd}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {norm_name},
            "dobs": {dob},
            "passport_numbers": set(),
            "doc_ids": {doc_id},
            "display_name": name
        }
        return new_person_key

    # 3. If only a name is available, matching is less reliable: create a new person.
    #    This can split one person into several profiles; fuzzy matching could be added later.
    if name:
        norm_name = normalize_name(name)
        new_person_key = f"person_{norm_name}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name, "canonical_dob": None,
            "names": {norm_name}, "dobs": set(), "passport_numbers": set(),
            "doc_ids": {doc_id}, "display_name": name
        }
        return new_person_key

    # 4. Unclassifiable for now: assign a generic unique person key
    generic_person_key = f"unidentified_person_{str(uuid.uuid4())[:6]}"
    current_persons_data[generic_person_key] = {
        "canonical_name": "Unknown", "canonical_dob": None,
        "names": set(), "dobs": set(), "passport_numbers": set(),
        "doc_ids": {doc_id}, "display_name": f"Unknown Person ({doc_id[:6]})"
    }
    return generic_person_key
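# Shape of a person profile entry (values are illustrative only):
#   person_profiles["person_AB1234567"] = {
#       "canonical_name": "Jane Doe", "canonical_dob": "1990-01-01",
#       "names": {"janedoe"}, "dobs": {"1990-01-01"},
#       "passport_numbers": {"AB1234567"}, "doc_ids": {"<uuid>"},
#       "display_name": "Jane Doe",
#   }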
def format_dataframe_data(current_files_data):
    # Dataframe columns:
    # "Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"
    df_rows = []
    for f_data in current_files_data:
        entities = f_data.get("entities") or {}
        df_rows.append([
            f_data["doc_id"][:8],  # Short ID
            f_data["filename"],
            f_data["status"],
            entities.get("doc_type", "N/A"),
            entities.get("name", "N/A"),
            entities.get("dob", "N/A"),
            entities.get("passport_no", "N/A"),
            f_data.get("assigned_person_key", "N/A")
        ])
    return df_rows
def format_persons_markdown(current_persons_data, current_files_data):
    if not current_persons_data:
        return "No persons identified yet."

    md_parts = ["## Classified Persons & Documents\n"]
    for p_key, p_data in current_persons_data.items():
        display_name = p_data.get('display_name', p_key)
        md_parts.append(f"### Person: {display_name} (Profile Key: {p_key})")
        if p_data.get("canonical_dob"):
            md_parts.append(f"* DOB: {p_data['canonical_dob']}")
        if p_data.get("passport_numbers"):
            md_parts.append(f"* Passport(s): {', '.join(p_data['passport_numbers'])}")
        md_parts.append("* Documents:")
        doc_ids_for_person = p_data.get("doc_ids", set())
        if doc_ids_for_person:
            for doc_id in doc_ids_for_person:
                # Look up the filename and detected type in current_files_data
                doc_detail = next((f for f in current_files_data if f["doc_id"] == doc_id), None)
                if doc_detail:
                    filename = doc_detail["filename"]
                    doc_type = (doc_detail.get("entities") or {}).get("doc_type", "Unknown Type")
                    md_parts.append(f"    - {filename} (`{doc_type}`)")
                else:
                    md_parts.append(f"    - Document ID: {doc_id[:8]} (details not found, unexpected)")
        else:
            md_parts.append("    - No documents currently assigned.")
        md_parts.append("\n---\n")
    return "\n".join(md_parts)
# --- Main Gradio Processing Function (Generator) ---
def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
    global processed_files_data, person_profiles  # Reset global state for each run
    processed_files_data = []
    person_profiles = {}

    if not OPENROUTER_API_KEY:
        yield (
            [["N/A", "ERROR", "OpenRouter API Key not configured.", "N/A", "N/A", "N/A", "N/A", "N/A"]],
            "Error: OpenRouter API Key not configured. Please set it in Space Secrets.",
            "{}", "API Key Missing. Processing halted."
        )
        return

    if not files_list:
        yield ([], "No files uploaded.", "{}", "Upload files to begin.")
        return

    # Initialize processed_files_data
    for i, file_obj in enumerate(files_list):
        doc_uid = str(uuid.uuid4())
        # gr.Files(type="filepath") provides plain path strings in recent Gradio releases;
        # older releases provide temp-file objects with a .name attribute. Handle both.
        filepath = file_obj if isinstance(file_obj, str) else file_obj.name
        processed_files_data.append({
            "doc_id": doc_uid,
            "filename": os.path.basename(filepath),
            "filepath": filepath,
            "status": "Queued",
            "ocr_json": None,
            "entities": None,
            "assigned_person_key": None
        })

    initial_df_data = format_dataframe_data(processed_files_data)
    initial_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (initial_df_data, initial_persons_md, "{}", f"Initialized. Found {len(files_list)} files.")

    # Iterate and process each file
    for i, file_data_item in enumerate(progress.tqdm(processed_files_data, desc="Processing Documents")):
        current_doc_id = file_data_item["doc_id"]
        current_filename = file_data_item["filename"]

        # 1. OCR Processing
        file_data_item["status"] = "OCR in Progress..."
        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)  # No change yet
        yield (df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) OCR for: {current_filename}")

        ocr_result = call_openrouter_ocr(file_data_item["filepath"])
        file_data_item["ocr_json"] = ocr_result  # Store full JSON

        if "error" in ocr_result:
            file_data_item["status"] = f"OCR Error: {ocr_result['error'][:50]}..."  # Truncate long errors
            df_data = format_dataframe_data(processed_files_data)
            yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Error on {current_filename}")
            continue  # Move to next file

        file_data_item["status"] = "OCR Done. Extracting Entities..."
        df_data = format_dataframe_data(processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Done for {current_filename}")

        # 2. Entity Extraction
        entities = extract_entities_from_ocr(ocr_result)
        file_data_item["entities"] = entities
        file_data_item["status"] = "Entities Extracted. Classifying..."
        df_data = format_dataframe_data(processed_files_data)  # Now entities will show up
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Entities for {current_filename}")

        # 3. Person Classification / Linking
        person_key = get_person_id_and_update_profiles(current_doc_id, entities, person_profiles)
        file_data_item["assigned_person_key"] = person_key
        file_data_item["status"] = "Classified"
        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)  # Now persons_md updates
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Classified {current_filename} -> {person_key}")

    final_df_data = format_dataframe_data(processed_files_data)
    final_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (final_df_data, final_persons_md, "{}", f"All {len(processed_files_data)} documents processed.")
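# Note: each yield above is a 4-tuple (dataframe rows, persons markdown, OCR JSON string,
# status text), in the same order as the outputs wired to process_button.click below.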
# --- Gradio UI Layout ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📄 Intelligent Document Processor & Classifier")
    gr.Markdown(
        "**Upload multiple documents (images of passports, bank statements, hotel reservations, photos, etc.). "
        "The system will perform OCR, attempt to extract key entities, and classify documents by the person they belong to.**\n"
        "Ensure `OPENROUTER_API_KEY` is set as a Secret in your Hugging Face Space."
    )

    if not OPENROUTER_API_KEY:
        gr.Markdown("<h3 style='color:red;'>⚠️ ERROR: `OPENROUTER_API_KEY` is not set in Space Secrets! OCR will fail.</h3>")

    with gr.Row():
        with gr.Column(scale=1):
            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath")
            process_button = gr.Button("Process Uploaded Documents", variant="primary")
            overall_status_textbox = gr.Textbox(label="Overall Progress", interactive=False, lines=1)

    gr.Markdown("---")
    gr.Markdown("## Document Processing Details")
    dataframe_headers = ["Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"]
    document_status_df = gr.Dataframe(
        headers=dataframe_headers,
        datatype=["str"] * len(dataframe_headers),  # All as strings for display simplicity
        label="Individual Document Status & Extracted Entities",
        row_count=(0, "dynamic"),  # Start empty, grows dynamically
        col_count=(len(dataframe_headers), "fixed"),
        wrap=True
    )

    ocr_json_output = gr.Code(label="Selected Document OCR JSON", language="json", interactive=False)

    gr.Markdown("---")
    person_classification_output_md = gr.Markdown("## Classified Persons & Documents\nNo persons identified yet.")

    # Event Handlers
    process_button.click(
        fn=process_uploaded_files,
        inputs=[files_input],
        outputs=[
            document_status_df,
            person_classification_output_md,
            ocr_json_output,  # Shows the most recent document's OCR JSON while processing
            overall_status_textbox
        ]
    )
    def display_selected_ocr(evt: gr.SelectData):
        # evt.index is a (row, col) pair for Dataframe selections.
        if evt.index is None or evt.index[0] is None:
            return "{}"  # Nothing selected or invalid selection
        selected_row_index = evt.index[0]
        if selected_row_index < len(processed_files_data):
            selected_doc_data = processed_files_data[selected_row_index]
            if selected_doc_data and selected_doc_data["ocr_json"]:
                return json.dumps(selected_doc_data["ocr_json"], indent=2)
        return "{ \"message\": \"No OCR data found for selected row or selection out of bounds.\" }"
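
    # Wire row selection in the dataframe to the OCR JSON viewer. A minimal sketch that
    # assumes the Dataframe component exposes a .select event delivering gr.SelectData
    # (true in recent Gradio releases); adjust if your Gradio version differs.
    document_status_df.select(
        fn=display_selected_ocr,
        inputs=None,
        outputs=[ocr_json_output]
    )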
if __name__ == "__main__":
    demo.queue().launch(debug=True)  # queue() keeps long-running processing responsive; share=True is not needed on Spaces