import gradio as gr
import base64
import requests
import json
import re
import os
import uuid
from datetime import datetime

# --- Configuration ---
# IMPORTANT: Set your OPENROUTER_API_KEY as a Hugging Face Space Secret.
# Never hard-code the key in source; it is read from the environment here.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
IMAGE_MODEL = "opengvlab/internvl3-14b:free"  # Using the free tier model as specified
OPENROUTER_API_URL = "https://openrouter.ai/api/v1/chat/completions"

# --- Global State (managed within Gradio's session if possible, or module-level for simplicity here) ---
# This is reset each time the processing function is called.
processed_files_data = []  # Stores dicts for each file's details and status
person_profiles = {}       # Stores dicts for each identified person and their documents


# --- Helper Functions ---
def extract_json_from_text(text):
    """Extract a JSON object from model output that may be wrapped in markdown fences."""
    if not text:
        return {"error": "Empty text provided for JSON extraction."}

    match_block = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE)
    if match_block:
        json_str = match_block.group(1)
    else:
        text_stripped = text.strip()
        if text_stripped.startswith("`") and text_stripped.endswith("`"):
            json_str = text_stripped[1:-1]
        else:
            json_str = text_stripped

    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        # Fall back to the substring between the first '{' and the last '}'.
        try:
            first_brace = json_str.find('{')
            last_brace = json_str.rfind('}')
            if first_brace != -1 and last_brace != -1 and last_brace > first_brace:
                potential_json_str = json_str[first_brace:last_brace + 1]
                return json.loads(potential_json_str)
            else:
                return {"error": f"Invalid JSON structure: {str(e)}", "original_text": text}
        except json.JSONDecodeError as e2:
            return {"error": f"Invalid JSON structure after attempting substring: {str(e2)}", "original_text": text}
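# Illustrative behaviour of extract_json_from_text (the example input is made up, not from the source):
#   extract_json_from_text('```json\n{"document_type_detected": "Passport"}\n```')
#   -> {"document_type_detected": "Passport"}
# Bare JSON, single-backtick-wrapped output, and JSON embedded in surrounding prose
# are handled by the fallback paths above.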
def get_ocr_prompt():
    return """You are an advanced OCR and information extraction AI.
Your task is to meticulously analyze this image and extract all relevant information.

Output Format Instructions:
Provide your response as a SINGLE, VALID JSON OBJECT. Do not include any explanatory text before or after the JSON.
The JSON object should have the following top-level keys:
- "document_type_detected": (string) Your best guess of the specific document type (e.g., "Passport", "National ID Card", "Driver's License", "Visa Sticker", "Hotel Confirmation Voucher", "Bank Statement", "Photo of a person").
- "extracted_fields": (object) A key-value map of all extracted information. Be comprehensive. Examples:
    - For passports/IDs: "Surname", "Given Names", "Full Name", "Document Number", "Nationality", "Date of Birth", "Sex", "Place of Birth", "Date of Issue", "Date of Expiry", "Issuing Authority", "Country Code".
    - For hotel reservations: "Guest Name", "Hotel Name", "Booking Reference", "Check-in Date", "Check-out Date".
    - For bank statements: "Account Holder Name", "Account Number", "Bank Name", "Statement Period", "Ending Balance".
    - For photos: "Description" (e.g., "Portrait of a person", "Group photo at a location"), "People Present" (array of strings if multiple).
- "mrz_data": (object or null) If a Machine Readable Zone (MRZ) is present:
    - "raw_mrz_lines": (array of strings) Each line of the MRZ.
    - "parsed_mrz": (object) Key-value pairs of parsed MRZ fields.
  If no MRZ, this field should be null.
- "full_text_ocr": (string) Concatenation of all text found on the document.

Extraction Guidelines:
1. Prioritize accuracy.
2. Extract all visible text. Include "Full Name" by combining given and surnames if possible.
3. For dates, try to use ISO 8601 format (YYYY-MM-DD) if possible, but retain original format if conversion is ambiguous.

Ensure the entire output strictly adheres to the JSON format.
"""


def call_openrouter_ocr(image_filepath):
    if not OPENROUTER_API_KEY:
        return {"error": "OpenRouter API Key not configured."}
    try:
        with open(image_filepath, "rb") as f:
            encoded_image = base64.b64encode(f.read()).decode("utf-8")

        mime_type = "image/jpeg"
        if image_filepath.lower().endswith(".png"):
            mime_type = "image/png"
        elif image_filepath.lower().endswith(".webp"):
            mime_type = "image/webp"
        data_url = f"data:{mime_type};base64,{encoded_image}"

        prompt_text = get_ocr_prompt()
        payload = {
            "model": IMAGE_MODEL,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt_text},
                        {"type": "image_url", "image_url": {"url": data_url}}
                    ]
                }
            ],
            "max_tokens": 3500,
            "temperature": 0.1,
        }
        headers = {
            "Authorization": f"Bearer {OPENROUTER_API_KEY}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://huggingface.co/spaces/YOUR_SPACE",
            "X-Title": "Gradio Document Processor"
        }

        response = requests.post(OPENROUTER_API_URL, headers=headers, json=payload, timeout=180)
        response.raise_for_status()
        result = response.json()

        if "choices" in result and result["choices"]:
            raw_content = result["choices"][0]["message"]["content"]
            return extract_json_from_text(raw_content)
        else:
            return {"error": "No 'choices' in API response from OpenRouter.", "details": result}
    except requests.exceptions.Timeout:
        return {"error": "API request timed out."}
    except requests.exceptions.RequestException as e:
        error_message = f"API Request Error: {str(e)}"
        if hasattr(e, 'response') and e.response is not None:
            error_message += f" Status: {e.response.status_code}, Response: {e.response.text}"
        return {"error": error_message}
    except Exception as e:
        return {"error": f"An unexpected error occurred during OCR: {str(e)}"}


def extract_entities_from_ocr(ocr_json):
    if not ocr_json or "extracted_fields" not in ocr_json or not isinstance(ocr_json.get("extracted_fields"), dict):
        doc_type_from_ocr = "Unknown"
        if isinstance(ocr_json, dict):  # ocr_json itself might be an error dict
            doc_type_from_ocr = ocr_json.get("document_type_detected", "Unknown (error in OCR)")
        return {"name": None, "dob": None, "passport_no": None, "doc_type": doc_type_from_ocr}

    fields = ocr_json["extracted_fields"]
    doc_type = ocr_json.get("document_type_detected", "Unknown")

    name_keys = ["full name", "name", "account holder name", "guest name"]
    dob_keys = ["date of birth", "dob"]
    passport_keys = ["document number", "passport number"]

    extracted_name = None
    for key in name_keys:
        for field_key, value in fields.items():
            if key == field_key.lower():
                extracted_name = str(value) if value else None
                break
        if extracted_name:
            break

    extracted_dob = None
    for key in dob_keys:
        for field_key, value in fields.items():
            if key == field_key.lower():
                extracted_dob = str(value) if value else None
                break
        if extracted_dob:
            break

    extracted_passport_no = None
    for key in passport_keys:
        for field_key, value in fields.items():
            if key == field_key.lower():
                extracted_passport_no = str(value).replace(" ", "").upper() if value else None
                break
        if extracted_passport_no:
            break

    return {
        "name": extracted_name,
        "dob": extracted_dob,
        "passport_no": extracted_passport_no,
        "doc_type": doc_type
    }


def normalize_name(name):
    if not name:
        return ""
    return "".join(filter(str.isalnum, name)).lower()
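# Illustrative shape of one person_profiles entry built by get_person_id_and_update_profiles
# below (all values are made-up examples, not from the source):
#
#   "person_X1234567": {
#       "canonical_name": "Jane Doe",
#       "canonical_dob": "1990-01-01",
#       "names": {"janedoe"},               # normalized via normalize_name
#       "dobs": {"1990-01-01"},
#       "passport_numbers": {"X1234567"},
#       "doc_ids": {"<uuid of each assigned document>"},
#       "display_name": "Jane Doe",
#   }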
def get_person_id_and_update_profiles(doc_id, entities, current_persons_data):
    passport_no = entities.get("passport_no")
    name = entities.get("name")
    dob = entities.get("dob")

    # 1) Passport number is the strongest identifier: match or create on it first.
    if passport_no:
        for p_key, p_data in current_persons_data.items():
            if passport_no in p_data.get("passport_numbers", set()):
                p_data["doc_ids"].add(doc_id)
                if name and not p_data.get("canonical_name"):
                    p_data["canonical_name"] = name
                if dob and not p_data.get("canonical_dob"):
                    p_data["canonical_dob"] = dob
                return p_key
        new_person_key = f"person_{passport_no}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {normalize_name(name)} if name else set(),
            "dobs": {dob} if dob else set(),
            "passport_numbers": {passport_no},
            "doc_ids": {doc_id},
            "display_name": name or f"Person (ID: {passport_no})"
        }
        return new_person_key

    # 2) Otherwise match on the (normalized name, date of birth) pair.
    if name and dob:
        norm_name = normalize_name(name)
        composite_key_nd = f"{norm_name}_{dob}"
        for p_key, p_data in current_persons_data.items():
            if norm_name in p_data.get("names", set()) and dob in p_data.get("dobs", set()):
                p_data["doc_ids"].add(doc_id)
                return p_key
        new_person_key = f"person_{composite_key_nd}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": dob,
            "names": {norm_name},
            "dobs": {dob},
            "passport_numbers": set(),
            "doc_ids": {doc_id},
            "display_name": name
        }
        return new_person_key

    # 3) Fall back to the name alone.
    if name:
        norm_name = normalize_name(name)
        new_person_key = f"person_{norm_name}_{str(uuid.uuid4())[:4]}"
        current_persons_data[new_person_key] = {
            "canonical_name": name,
            "canonical_dob": None,
            "names": {norm_name},
            "dobs": set(),
            "passport_numbers": set(),
            "doc_ids": {doc_id},
            "display_name": name
        }
        return new_person_key

    # 4) No usable identifiers: create an "unidentified" profile.
    generic_person_key = f"unidentified_person_{str(uuid.uuid4())[:6]}"
    current_persons_data[generic_person_key] = {
        "canonical_name": "Unknown",
        "canonical_dob": None,
        "names": set(),
        "dobs": set(),
        "passport_numbers": set(),
        "doc_ids": {doc_id},
        "display_name": f"Unknown Person ({doc_id[:6]})"
    }
    return generic_person_key


def format_dataframe_data(current_files_data):
    df_rows = []
    for f_data in current_files_data:
        entities = f_data.get("entities") or {}  # entities is None until extraction succeeds
        df_rows.append([
            f_data.get("doc_id", "N/A")[:8],
            f_data.get("filename", "N/A"),
            f_data.get("status", "N/A"),
            entities.get("doc_type", "N/A"),
            entities.get("name", "N/A"),
            entities.get("dob", "N/A"),
            entities.get("passport_no", "N/A"),
            f_data.get("assigned_person_key", "N/A")
        ])
    return df_rows
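# Illustrative row produced by format_dataframe_data, matching the dataframe headers
# defined in the UI (values are made up):
#   ["a1b2c3d4", "passport_scan.jpg", "Classified", "Passport",
#    "Jane Doe", "1990-01-01", "X1234567", "person_X1234567"]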
def format_persons_markdown(current_persons_data, current_files_data):
    if not current_persons_data:
        return "No persons identified yet."

    md_parts = ["## Classified Persons & Documents\n"]
    for p_key, p_data in current_persons_data.items():
        display_name = p_data.get('display_name', p_key)
        md_parts.append(f"### Person: {display_name} (Profile Key: {p_key})")
        if p_data.get("canonical_dob"):
            md_parts.append(f"* DOB: {p_data['canonical_dob']}")
        if p_data.get("passport_numbers"):
            md_parts.append(f"* Passport(s): {', '.join(p_data['passport_numbers'])}")
        md_parts.append("* Documents:")
        doc_ids_for_person = p_data.get("doc_ids", set())
        if doc_ids_for_person:
            for doc_id in doc_ids_for_person:
                doc_detail = next((f for f in current_files_data if f["doc_id"] == doc_id), None)
                if doc_detail:
                    filename = doc_detail.get("filename", "Unknown File")
                    doc_entities = doc_detail.get("entities") or {}
                    doc_type = doc_entities.get("doc_type", "Unknown Type")
                    md_parts.append(f"    - {filename} (`{doc_type}`)")
                else:
                    md_parts.append(f"    - Document ID: {doc_id[:8]} (details error)")
        else:
            md_parts.append("    - No documents currently assigned.")
        md_parts.append("\n---\n")
    return "\n".join(md_parts)
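# process_uploaded_files below is a generator: each yielded 4-tuple is
# (dataframe rows, persons markdown, OCR JSON string, status message), which Gradio
# streams into the four outputs wired up in process_button.click() further down.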
def process_uploaded_files(files_list, progress=gr.Progress(track_tqdm=True)):
    global processed_files_data, person_profiles
    processed_files_data = []
    person_profiles = {}

    if not OPENROUTER_API_KEY:
        yield (
            [["N/A", "ERROR", "OpenRouter API Key not configured.", "N/A", "N/A", "N/A", "N/A", "N/A"]],
            "Error: OpenRouter API Key not configured. Please set it in Space Secrets.",
            "{}",
            "API Key Missing. Processing halted."
        )
        return

    if not files_list:
        yield ([], "No files uploaded.", "{}", "Upload files to begin.")
        return

    for i, file_obj in enumerate(files_list):
        doc_uid = str(uuid.uuid4())
        # With gr.Files(type="filepath") each item is already a path string; otherwise it
        # may be a file-like object exposing .name. Handle both.
        filepath = file_obj.name if hasattr(file_obj, 'name') else (file_obj if isinstance(file_obj, str) else None)
        processed_files_data.append({
            "doc_id": doc_uid,
            "filename": os.path.basename(filepath) if filepath else f"file_{i+1}.unknown",
            "filepath": filepath,
            "status": "Queued",
            "ocr_json": None,
            "entities": None,
            "assigned_person_key": None
        })

    initial_df_data = format_dataframe_data(processed_files_data)
    initial_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (initial_df_data, initial_persons_md, "{}", f"Initialized. Found {len(files_list)} files.")

    for i, file_data_item in enumerate(progress.tqdm(processed_files_data, desc="Processing Documents")):
        current_doc_id = file_data_item["doc_id"]
        current_filename = file_data_item["filename"]

        if not file_data_item["filepath"]:  # Skip entries without a usable file path
            file_data_item["status"] = "Error: Invalid file path"
            df_data = format_dataframe_data(processed_files_data)
            persons_md = format_persons_markdown(person_profiles, processed_files_data)
            yield (df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) Error with file {current_filename}")
            continue

        file_data_item["status"] = "OCR in Progress..."
        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (df_data, persons_md, "{}", f"({i+1}/{len(processed_files_data)}) OCR for: {current_filename}")

        ocr_result = call_openrouter_ocr(file_data_item["filepath"])
        file_data_item["ocr_json"] = ocr_result

        if "error" in ocr_result:
            file_data_item["status"] = f"OCR Error: {str(ocr_result['error'])[:50]}..."
            df_data = format_dataframe_data(processed_files_data)
            yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Error on {current_filename}")
            continue

        file_data_item["status"] = "OCR Done. Extracting Entities..."
        df_data = format_dataframe_data(processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) OCR Done for {current_filename}")

        entities = extract_entities_from_ocr(ocr_result)
        file_data_item["entities"] = entities
        file_data_item["status"] = "Entities Extracted. Classifying..."
        df_data = format_dataframe_data(processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Entities for {current_filename}")

        person_key = get_person_id_and_update_profiles(current_doc_id, entities, person_profiles)
        file_data_item["assigned_person_key"] = person_key
        file_data_item["status"] = "Classified"
        df_data = format_dataframe_data(processed_files_data)
        persons_md = format_persons_markdown(person_profiles, processed_files_data)
        yield (df_data, persons_md, json.dumps(ocr_result, indent=2), f"({i+1}/{len(processed_files_data)}) Classified {current_filename} -> {person_key}")

    final_df_data = format_dataframe_data(processed_files_data)
    final_persons_md = format_persons_markdown(person_profiles, processed_files_data)
    yield (final_df_data, final_persons_md, "{}", f"All {len(processed_files_data)} documents processed.")


# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📄 Intelligent Document Processor & Classifier")
    gr.Markdown(
        "**Upload multiple documents (images of passports, bank statements, hotel reservations, photos, etc.). "
        "The system will perform OCR, attempt to extract key entities, and classify documents by the person they belong to.**\n"
        "Ensure `OPENROUTER_API_KEY` is set as a Secret in your Hugging Face Space."
    )

    if not OPENROUTER_API_KEY:
        gr.Markdown("⚠️ **ERROR:** `OPENROUTER_API_KEY` is not set in Space Secrets! OCR will fail.")

    with gr.Row():
        with gr.Column(scale=1):
            files_input = gr.Files(label="Upload Document Images (Bulk)", file_count="multiple", type="filepath")
            process_button = gr.Button("🚀 Process Uploaded Documents", variant="primary")
            overall_status_textbox = gr.Textbox(label="Overall Progress", interactive=False, lines=1)

    gr.Markdown("---")
    gr.Markdown("## Document Processing Details")
    dataframe_headers = ["Doc ID (short)", "Filename", "Status", "Detected Type", "Name", "DOB", "Passport No.", "Assigned Person Key"]
    document_status_df = gr.Dataframe(
        headers=dataframe_headers,
        datatype=["str"] * len(dataframe_headers),
        label="Individual Document Status & Extracted Entities",
        row_count=(1, "dynamic"),  # Start with one row; grows dynamically
        col_count=(len(dataframe_headers), "fixed"),
        wrap=True
    )
    ocr_json_output = gr.Code(label="Selected Document OCR JSON", language="json", interactive=False)

    gr.Markdown("---")
    person_classification_output_md = gr.Markdown("## Classified Persons & Documents\nNo persons identified yet.")

    process_button.click(
        fn=process_uploaded_files,
        inputs=[files_input],
        outputs=[
            document_status_df,
            person_classification_output_md,
            ocr_json_output,
            overall_status_textbox
        ]
    )

    @document_status_df.select(inputs=None, outputs=ocr_json_output, show_progress="hidden")
    def display_selected_ocr(evt: gr.SelectData):
        if evt.index is None or evt.index[0] is None:
            return "{}"
        selected_row_index = evt.index[0]
        # processed_files_data is module-level, so it is accessible here; in a more complex
        # app it could be passed via gr.State or wrapped in a class instead.
        if 0 <= selected_row_index < len(processed_files_data):
            selected_doc_data = processed_files_data[selected_row_index]
            if selected_doc_data and selected_doc_data.get("ocr_json"):
                ocr_data_to_display = selected_doc_data["ocr_json"]
                if isinstance(ocr_data_to_display, str):  # Should not happen if stored correctly
                    try:
                        ocr_data_to_display = json.loads(ocr_data_to_display)
                    except json.JSONDecodeError:
                        return json.dumps({"error": "Stored OCR data is not a valid JSON string."}, indent=2)
                return json.dumps(ocr_data_to_display, indent=2, ensure_ascii=False)
        return json.dumps(
            {"message": "No OCR data found for selected row or selection out of bounds "
                        "(check if processing is complete). Current rows: " + str(len(processed_files_data))},
            indent=2
        )


if __name__ == "__main__":
    demo.queue().launch(debug=True, share=os.environ.get("GRADIO_SHARE", "true").lower() == "true")