# msIntui — Gradio Space
# feat: update model downloading from Azure Blob Storage and environment configuration
import os
import base64
import gradio as gr
import json
from datetime import datetime
# Project-local pipeline modules (detection, aggregation, chat, storage).
# NOTE(review): several imports below (Path, nx, plt, shutil, BBox, np, time,
# HfApi/login) do not appear to be used in this file — confirm before pruning.
from symbol_detection import run_detection_with_optimal_threshold
from line_detection_ai import DiagramDetectionPipeline, LineDetector, LineConfig, ImageConfig, DebugHandler, \
    PointConfig, JunctionConfig, PointDetector, JunctionDetector, SymbolConfig, SymbolDetector, TagConfig, TagDetector
from data_aggregation_ai import DataAggregator
from chatbot_agent import get_assistant_response
from storage import StorageFactory, LocalStorage
import traceback
from text_detection_combined import process_drawing
from pathlib import Path
from pdf_processor import DocumentProcessor
import networkx as nx
import logging
import matplotlib.pyplot as plt
from dotenv import load_dotenv
import torch
from graph_visualization import create_graph_visualization
import shutil
from detection_schema import BBox  # Add this import
import cv2
import numpy as np
import time
from huggingface_hub import HfApi, login
from download_models import download_from_azure
# Load environment variables from .env file
load_dotenv()
# Configure logging at the start of the file
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
# Get logger for this module
logger = logging.getLogger(__name__)
# Disable duplicate logs from other modules: raise noisy third-party and
# pipeline loggers to WARNING so only this module's INFO lines show.
logging.getLogger('PIL').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('gradio').setLevel(logging.WARNING)
logging.getLogger('networkx').setLevel(logging.WARNING)
logging.getLogger('line_detection_ai').setLevel(logging.WARNING)
logging.getLogger('symbol_detection').setLevel(logging.WARNING)
# Only log important messages
def log_process_step(message, level=logging.INFO):
    """Log a processing step, filtering out routine INFO chatter.

    Warnings and errors are always emitted at their requested level; messages
    below WARNING are only logged (at INFO) when they mark a milestone, i.e.
    contain "completed" or "generated".
    """
    if level >= logging.WARNING:
        logger.log(level, message)
        return
    lowered = message.lower()
    if "completed" in lowered or "generated" in lowered:
        logger.info(message)
# Helper function to format timestamps
def get_timestamp():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
def format_message(role, content):
    """Build one chat-history entry in the {'role', 'content'} messages format."""
    return dict(role=role, content=content)
# Load avatar images for agents, base64-encoded for inline <img> data URIs.
# NOTE(review): this runs at import time — a missing assets/ file raises
# immediately on module load; confirm the assets ship with the app.
localStorage = LocalStorage()
agent_avatar = base64.b64encode(localStorage.load_file("assets/AiAgent.png")).decode()
llm_avatar = base64.b64encode(localStorage.load_file("assets/llm.png")).decode()
user_avatar = base64.b64encode(localStorage.load_file("assets/user.png")).decode()
# Chat message formatting with avatars and enhanced HTML for readability
def chat_message(role, message, avatar, timestamp):
    """Render one chat turn as an HTML bubble with avatar and timestamp.

    Converts a small Markdown subset (fenced/inline code, bold, headings,
    newlines) to HTML before embedding it in the bubble template.

    Bug fixed: the original used chained str.replace calls, e.g.
    .replace("**", "<strong>").replace("**", "</strong>") — the first call
    consumed every "**" so a closing tag was never emitted (same for
    backticks), and "\n" was replaced before the "\n1. " list patterns,
    making those replacements dead code. Paired regex substitutions emit
    matching open/close tags instead.

    Args:
        role: "user" or "assistant" (used as a CSS class).
        message: raw message text, possibly containing Markdown markers.
        avatar: base64-encoded PNG for the speaker's avatar.
        timestamp: preformatted time string shown under the bubble.

    Returns:
        An HTML fragment string.
    """
    import re  # local import: keeps this fix self-contained

    # Fenced code blocks first, so their contents are not re-processed below.
    formatted_message = re.sub(r"```(.*?)```", r"<pre><code>\1</code></pre>",
                               message, flags=re.DOTALL)
    # Inline code.
    formatted_message = re.sub(r"`([^`]+)`", r"<code>\1</code>", formatted_message)
    # Bold: pair each **...** into <strong>...</strong>.
    formatted_message = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", formatted_message)
    # Headings: longest marker first, with proper closing tags, per line.
    formatted_message = re.sub(r"^### *(.+)$", r"<h3>\1</h3>", formatted_message, flags=re.MULTILINE)
    formatted_message = re.sub(r"^## *(.+)$", r"<h2>\1</h2>", formatted_message, flags=re.MULTILINE)
    formatted_message = re.sub(r"^# *(.+)$", r"<h1>\1</h1>", formatted_message, flags=re.MULTILINE)
    # Newlines last, so the line-anchored patterns above still see them.
    formatted_message = formatted_message.replace("\n", "<br>")
    return f"""
    <div class="chat-message {role}">
        <img src="data:image/png;base64,{avatar}" class="avatar"/>
        <div>
            <div class="speech-bubble {role}-bubble">{formatted_message}</div>
            <div class="timestamp">{timestamp}</div>
        </div>
    </div>
    """
def resize_to_fit(image_path, max_width=800, max_height=600):
    """Resize the image at *image_path* to fit the editor viewport.

    Aspect ratio is preserved; the image is always rescaled (up or down) so
    it exactly fits within max_width x max_height.

    Returns:
        (resized_bgr_image, scale), or (None, 1.0) if the file cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        return None, 1.0
    h, w = img.shape[:2]
    # Uniform scale factor: the tighter of the two axis constraints wins.
    scale = min(max_width / w, max_height / h)
    new_size = (int(w * scale), int(h * scale))
    return cv2.resize(img, new_size), scale
# Main processing function for P&ID steps
def process_pnid(image_file, progress=gr.Progress()):
    """Process P&ID document with real-time progress updates.

    Generator wired to the "Process Document" button in create_ui(): each
    `yield outputs` pushes the 9-slot list below to the bound UI components.

    outputs layout (must match the outputs list in create_ui):
        [0] progress text     [1] original P&ID image  [2] symbol overlay
        [3] tag/text overlay  [4] line overlay         [5] aggregated overlay
        [6] knowledge graph   [7] chat history         [8] aggregated JSON path

    Args:
        image_file: uploaded file object from gr.File (``.name`` is read).
        progress: Gradio progress tracker.

    Raises:
        Re-raises any failure after pushing an error message into the chat
        history slot.
        NOTE(review): if image_file is None, the outer except handler touches
        `outputs` before it is assigned, raising NameError and masking the
        original ValueError — confirm and fix separately.
    """
    try:
        if image_file is None:
            raise ValueError("No file uploaded. Please upload a file first.")
        progress_text = []
        outputs = [None] * 9  # Changed from 8 to 9 to match UI outputs
        base_name = os.path.splitext(os.path.basename(image_file.name))[0] + "_page_1"
        # Initialize chat history with proper format
        chat_history = [{"role": "assistant", "content": "Welcome! Upload a P&ID to begin analysis."}]
        outputs[7] = chat_history  # Chat history moved to index 7

        def update_progress(step, message):
            # Append a timestamped line and refresh both the textbox and the bar.
            progress_text.append(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - {message}")
            outputs[0] = "\n".join(progress_text)  # Progress text
            progress(step)

        # Initialize storage and results directory
        storage = StorageFactory.get_storage()
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)
        # Clean results directory (top-level files only; subdirectories are kept)
        logger.info("Cleaned results directory: results")
        for file in os.listdir(results_dir):
            file_path = os.path.join(results_dir, file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except Exception as e:
                # Best-effort cleanup: log and continue.
                logger.error(f"Error deleting file {file_path}: {str(e)}")
        # Step 1: File Upload (10%)
        logger.info(f"Processing file: {os.path.basename(image_file.name)}")
        update_progress(0.1, "Step 1/7: File uploaded successfully")
        yield outputs
        # Step 2: Document Processing - Get high quality PNG
        update_progress(0.2, "Step 2/7: Processing document...")
        doc_processor = DocumentProcessor(storage)
        processed_pages = doc_processor.process_document(
            file_path=image_file,
            output_dir=results_dir
        )
        if not processed_pages:
            raise ValueError("No pages processed from document")
        # Use high quality PNG for everything; only page 1 is processed.
        high_quality_png = processed_pages[0]
        outputs[1] = high_quality_png  # P&ID Tab shows original high quality
        update_progress(0.25, "Document loaded and displayed")
        yield outputs
        # Step 3: Symbol Detection using high quality PNG
        # NOTE(review): steps 3 and 4 make no update_progress calls, so the
        # bar jumps from 0.25 straight to 0.80 — confirm intended.
        detection_image_path, detection_json_path, _, diagram_bbox = run_detection_with_optimal_threshold(
            high_quality_png,  # Use high quality PNG
            results_dir=results_dir,
            file_name=os.path.basename(high_quality_png),
            storage=storage,
            resize_image=False  # Don't resize
        )
        outputs[2] = detection_image_path  # Symbols Tab
        symbol_json_path = detection_json_path
        # Step 4: Text Detection using high quality PNG
        text_results, text_summary = process_drawing(
            high_quality_png,  # Use high quality PNG
            results_dir,
            storage
        )
        text_json_path = text_results['json_path']
        outputs[3] = text_results['image_path']  # Tags Tab
        # Step 5: Line Detection (80%)
        update_progress(0.80, "Step 5/7: Line Detection")
        yield outputs
        try:
            # Initialize components
            debug_handler = DebugHandler(enabled=True, storage=storage)
            # Configure detectors
            line_config = LineConfig()
            point_config = PointConfig()
            junction_config = JunctionConfig()
            symbol_config = SymbolConfig(
                model_path="models/Intui_SDM_41.pt",
                confidence_threshold=0.5,
                nms_threshold=0.3
            )
            tag_config = TagConfig(
                model_path="models/tag_detection.json",
                confidence_threshold=0.5
            )
            # Create all required detectors
            symbol_detector = SymbolDetector(
                config=symbol_config,
                debug_handler=debug_handler
            )
            tag_detector = TagDetector(
                config=tag_config,
                debug_handler=debug_handler
            )
            # NOTE(review): device is hard-coded to CUDA; this fails on a
            # CPU-only host — confirm deployment always has a GPU.
            line_detector = LineDetector(
                config=line_config,
                model_path="models/deeplsd_md.tar",
                model_config={"detect_lines": True},
                device=torch.device("cuda"),
                debug_handler=debug_handler
            )
            point_detector = PointDetector(
                config=point_config,
                debug_handler=debug_handler
            )
            junction_detector = JunctionDetector(
                config=junction_config,
                debug_handler=debug_handler
            )
            # Create pipeline with all detectors
            pipeline = DiagramDetectionPipeline(
                tag_detector=tag_detector,
                symbol_detector=symbol_detector,
                line_detector=line_detector,
                point_detector=point_detector,
                junction_detector=junction_detector,
                storage=storage,
                debug_handler=debug_handler
            )
            # Run pipeline with original high-res image
            line_results = pipeline.run(
                image_path=high_quality_png,
                output_dir=results_dir,
                config=ImageConfig()
            )
            line_json_path = line_results.json_path
            outputs[4] = line_results.image_path
            # Verify line detection output
            if not os.path.exists(line_json_path):
                raise ValueError(f"Line detection JSON not found: {line_json_path}")
            # Verify line detection JSON content
            with open(line_json_path, 'r') as f:
                line_data = json.load(f)
            if 'lines' not in line_data:
                raise ValueError(f"Invalid line detection data format in {line_json_path}")
            logger.info(f"Line detection completed successfully with {len(line_data['lines'])} lines")
            # Verify all required JSONs exist before aggregation
            required_jsons = {
                'symbols': symbol_json_path,
                'texts': text_json_path,
                'lines': line_json_path
            }
            for name, path in required_jsons.items():
                if not os.path.exists(path):
                    raise ValueError(f"{name} JSON not found: {path}")
                # Verify JSON can be loaded
                with open(path, 'r') as f:
                    data = json.load(f)
                logger.info(f"Loaded {name} JSON with {len(data.get('detections', data.get('lines', [])))} items")
            # Data Aggregation
            aggregator = DataAggregator(storage=storage)
            aggregated_result = aggregator.process_data(
                image_path=high_quality_png,
                output_dir=results_dir,
                symbols_path=symbol_json_path,
                texts_path=text_json_path,
                lines_path=line_json_path
            )
            # Verify aggregation result before graph creation
            if not aggregated_result.get('success'):
                raise ValueError(f"Data aggregation failed: {aggregated_result.get('error')}")
            aggregated_json_path = aggregated_result['json_path']
            if not os.path.exists(aggregated_json_path):
                raise ValueError(f"Aggregated JSON not found: {aggregated_json_path}")
            # Verify aggregated JSON content
            with open(aggregated_json_path, 'r') as f:
                aggregated_data = json.load(f)
            required_keys = ['nodes', 'edges', 'symbols', 'texts', 'lines']
            missing_keys = [k for k in required_keys if k not in aggregated_data]
            if missing_keys:
                raise ValueError(f"Aggregated JSON missing required keys: {missing_keys}")
            logger.info("Aggregation completed successfully with:")
            logger.info(f"- {len(aggregated_data['nodes'])} nodes")
            logger.info(f"- {len(aggregated_data['edges'])} edges")
            # After aggregation, create graph visualization
            update_progress(0.85, "Step 6/7: Creating Knowledge Graph")
            try:
                # Create graph visualization
                graph_results = create_graph_visualization(
                    json_path=aggregated_json_path,
                    output_dir=results_dir,
                    base_name=base_name,
                    save_plot=True
                )
                if not graph_results.get('success'):
                    logger.error(f"Error in graph generation: {graph_results.get('error')}")
                    raise Exception(graph_results.get('error'))
                graph_path = f"results/{base_name}_graph_visualization.png"
                if not os.path.exists(graph_path):
                    raise Exception("Graph visualization file not created")
                update_progress(0.90, "Step 6/7: Knowledge Graph Created")
            except Exception as e:
                logger.error(f"Error creating graph visualization: {str(e)}")
                raise
            # Fix output assignments
            # NOTE(review): this stores the raw list, whereas update_progress
            # stores "\n".join(progress_text); the next update_progress call
            # overwrites it — confirm the textbox tolerates a list here.
            outputs[0] = progress_text  # Progress text
            outputs[1] = high_quality_png  # P&ID
            outputs[2] = detection_image_path  # Symbols
            outputs[3] = text_results['image_path']  # Tags
            outputs[4] = line_results.image_path  # Lines
            outputs[5] = f"results/{base_name}_aggregated.png"  # Aggregated
            outputs[6] = graph_path  # Graph visualization
            outputs[7] = chat_history  # Chat
            outputs[8] = aggregated_json_path  # JSON state
            # Update progress with all steps
            update_progress(0.95, "Step 7/7: Finalizing Results")
            chat_history = [{"role": "assistant", "content": "Processing complete! I can help answer questions about the P&ID contents."}]
            outputs[7] = chat_history
            # NOTE(review): the leading "β" looks like a mojibake of a
            # checkmark emoji — confirm against the original source.
            update_progress(1.0, "β Processing Complete")
            yield outputs
        except Exception as e:
            # Update chat with error message, then let the outer handler log.
            chat_history = [{"role": "assistant", "content": f"Error during processing: {str(e)}"}]
            outputs[7] = chat_history
            raise
    except Exception as e:
        logger.error(f"Error in process_pnid: {str(e)}")
        logger.error(f"Stack trace:\n{traceback.format_exc()}")
        # Update chat with error message
        chat_history = [{"role": "assistant", "content": f"Error: {str(e)}"}]
        outputs[7] = chat_history
        raise
# Separate function for Chat interaction
def handle_user_message(user_input, chat_history, json_path_state):
    """Handle user messages and generate responses.

    NOTE(review): this handler renders turns as HTML strings via
    chat_message() and concatenates them onto chat_history with `+`, which
    only works if chat_history is itself a string — yet create_ui() wires
    the chat to its local handle_chat() (messages format) instead. This
    looks like a legacy handler; confirm before relying on it.

    Args:
        user_input: raw text from the chat textbox.
        chat_history: existing rendered history the new turns are appended to.
        json_path_state: path to the aggregated-results JSON from processing.

    Returns:
        The updated chat history (unchanged when input is blank or on error).
    """
    try:
        # Ignore empty / whitespace-only input.
        if not user_input or not user_input.strip():
            return chat_history
        # Add user message
        timestamp = get_timestamp()
        new_history = chat_history + chat_message("user", user_input, user_avatar, timestamp)
        # Check if json_path exists and is valid
        if not json_path_state or not os.path.exists(json_path_state):
            error_message = "Please upload and process a P&ID document first."
            return new_history + chat_message("assistant", error_message, agent_avatar, get_timestamp())
        try:
            # Log for debugging
            logger.info(f"Sending question to assistant: {user_input}")
            logger.info(f"Using JSON path: {json_path_state}")
            # Generate response
            response = get_assistant_response(user_input, json_path_state)
            # Handle the response: it may be a string/dict or a generator.
            if isinstance(response, (str, dict)):
                response_text = str(response)
            else:
                try:
                    # Try to get the first response from generator
                    response_text = next(response) if hasattr(response, '__next__') else str(response)
                except StopIteration:
                    # Generator produced nothing.
                    response_text = "I apologize, but I couldn't generate a response."
                except Exception as e:
                    logger.error(f"Error processing response: {str(e)}")
                    response_text = "I apologize, but I encountered an error processing your request."
            logger.info(f"Generated response: {response_text}")
            if not response_text.strip():
                response_text = "I apologize, but I couldn't generate a response. Please try asking your question differently."
            # Add response to chat history
            new_history += chat_message("assistant", response_text, agent_avatar, get_timestamp())
        except Exception as e:
            # Assistant failure: keep the conversation alive with an apology.
            logger.error(f"Error generating response: {str(e)}")
            logger.error(traceback.format_exc())
            error_message = "I apologize, but I encountered an error processing your request. Please try again."
            new_history += chat_message("assistant", error_message, agent_avatar, get_timestamp())
        return new_history
    except Exception as e:
        # Catch-all: never propagate into the UI event loop.
        logger.error(f"Chat error: {str(e)}")
        logger.error(traceback.format_exc())
        return chat_history + chat_message(
            "assistant",
            "I apologize, but something went wrong. Please try again.",
            agent_avatar,
            get_timestamp()
        )
# Update custom CSS
# NOTE(review): `custom_css` is not referenced anywhere in this file —
# create_ui() builds its own `css` string — confirm it is unused before
# removing. Kept verbatim here.
custom_css = """
.full-height-row {
    height: calc(100vh - 150px); /* Adjusted height */
    margin: 0;
    padding: 10px;
}
.upload-box {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    border: 1px solid #3a3a3a;
}
.status-box-container {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    height: calc(100vh - 350px); /* Reduced height */
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.status-box {
    font-family: 'Courier New', monospace;
    font-size: 12px;
    line-height: 1.4;
    background-color: #1a1a1a;
    color: #00ff00;
    padding: 10px;
    border-radius: 5px;
    height: calc(100% - 40px); /* Adjust for header */
    overflow-y: auto;
    white-space: pre-wrap;
    word-wrap: break-word;
    border: none;
}
.preview-tabs {
    height: calc(100vh - 100px); /* Increased container height from 200px */
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.chat-container {
    height: 100%; /* Take full height */
    display: flex;
    flex-direction: column;
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
}
.chatbox {
    flex: 1; /* Take remaining space */
    overflow-y: auto;
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    color: #ffffff;
    min-height: 200px; /* Ensure minimum height */
}
.chat-input-group {
    height: auto; /* Allow natural height */
    min-height: 120px; /* Minimum height for input area */
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-top: auto; /* Push to bottom */
}
.chat-input {
    background: #2a2a2a;
    color: #ffffff;
    border: 1px solid #3a3a3a;
    border-radius: 5px;
    padding: 12px;
    min-height: 80px;
    width: 100%;
    margin-bottom: 10px;
}
.send-button {
    width: 100%;
    background: #4a4a4a;
    color: #ffffff;
    border-radius: 5px;
    border: none;
    padding: 12px;
    cursor: pointer;
    transition: background-color 0.3s;
}
.result-image {
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    margin: 10px 0;
    background: #ffffff;
}
.chat-message {
    display: flex;
    margin-bottom: 1rem;
    align-items: flex-start;
}
.chat-message .avatar {
    width: 40px;
    height: 40px;
    margin-right: 10px;
    border-radius: 50%;
}
.chat-message .speech-bubble {
    background: #2a2a2a;
    padding: 10px 15px;
    border-radius: 10px;
    max-width: 80%;
    margin-bottom: 5px;
}
.chat-message .timestamp {
    font-size: 0.8em;
    color: #666;
}
.logo-row {
    width: 100%;
    background-color: #1a1a1a;
    padding: 10px 0;
    margin: 0;
    border-bottom: 1px solid #3a3a3a;
}
"""
def create_ui():
    """Build the Gradio Blocks app.

    Layout: logo row, then three columns — upload/progress (left), tabbed
    previews (center), and chat (right) — wired to process_pnid() and the
    local handle_chat() callback.

    Returns:
        gr.Blocks: the assembled (not yet launched) demo.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    logo_path = os.path.join(current_dir, "assets", "intuigence.png")
    # Inline CSS for the orange theme; the module-level `custom_css` is
    # NOT used here.
    css = """
    /* Theme colors */
    :root {
        --orange-primary: #ff6b2b;
        --orange-hover: #ff8651;
        --orange-light: rgba(255, 107, 43, 0.1);
    }
    /* Logo styling */
    .logo-container {
        padding: 10px 20px;
        margin-bottom: 10px;
        text-align: left;
        width: 100%;
        background: #1a1a1a; /* Match app background */
        border-bottom: 1px solid #3a3a3a;
    }
    .logo-container img {
        max-height: 40px;
        width: auto;
        display: inline-block !important;
    }
    /* Hide download and fullscreen buttons for logo */
    .logo-container .download-button,
    .logo-container .fullscreen-button {
        display: none !important;
    }
    /* Adjust main content padding */
    .main-content {
        padding-top: 10px;
    }
    /* Custom orange theme */
    .primary-button {
        background: var(--orange-primary) !important;
        color: white !important;
        border: none !important;
    }
    .primary-button:hover {
        background: var(--orange-hover) !important;
    }
    /* Tab styling */
    .tabs > .tab-nav > button.selected {
        border-color: var(--orange-primary) !important;
        color: var(--orange-primary) !important;
    }
    .tabs > .tab-nav > button:hover {
        border-color: var(--orange-hover) !important;
        color: var(--orange-hover) !important;
    }
    /* File upload button */
    .file-upload {
        background: var(--orange-primary) !important;
    }
    /* Progress bar */
    .progress-bar > div {
        background: var(--orange-primary) !important;
    }
    /* Tags and labels */
    .label-wrap {
        background: var(--orange-primary) !important;
    }
    /* Selected/active states */
    .selected, .active, .focused {
        border-color: var(--orange-primary) !important;
        color: var(--orange-primary) !important;
    }
    /* Links and interactive elements */
    a, .link, .interactive {
        color: var(--orange-primary) !important;
    }
    a:hover, .link:hover, .interactive:hover {
        color: var(--orange-hover) !important;
    }
    /* Input focus states */
    input:focus, textarea:focus {
        border-color: var(--orange-primary) !important;
        box-shadow: 0 0 0 1px var(--orange-light) !important;
    }
    /* Checkbox and radio */
    input[type="checkbox"]:checked, input[type="radio"]:checked {
        background-color: var(--orange-primary) !important;
        border-color: var(--orange-primary) !important;
    }
    """
    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
        # Logo row (before main content)
        with gr.Row(elem_classes="logo-container"):
            gr.Image(
                value=logo_path,
                show_label=False,
                container=False,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                height=40
            )
        # State for storing file path
        # NOTE(review): file_path is never read or written below — confirm
        # it can be removed.
        file_path = gr.State()
        json_path = gr.State()  # holds the aggregated-JSON path from process_pnid
        # Main content row
        with gr.Row(elem_classes="main-content"):
            # Left column - File Upload & Processing
            with gr.Column(scale=3, elem_classes="column-panel"):
                file_output = gr.File(label="Upload P&ID Document")
                process_button = gr.Button(
                    "Process Document",
                    elem_classes="primary-button"  # Add custom class
                )
                progress_output = gr.Textbox(
                    label="Progress",
                    value="Waiting for document...",
                    interactive=False
                )
            # Center column - Preview Panel
            with gr.Column(scale=5, elem_classes="column-panel preview-panel"):
                with gr.Tabs() as tabs:
                    with gr.TabItem("P&ID"):
                        input_image = gr.Image(type="filepath", label="Original")
                    with gr.TabItem("Symbols"):
                        symbol_image = gr.Image(type="filepath", label="Detected Symbols")
                    with gr.TabItem("Tags"):
                        text_image = gr.Image(type="filepath", label="Detected Tags")
                    with gr.TabItem("Lines"):
                        line_image = gr.Image(type="filepath", label="Detected Lines")
                    with gr.TabItem("Aggregated"):
                        aggregated_image = gr.Image(type="filepath", label="Aggregated Results")
                    with gr.TabItem("Knowledge Graph"):
                        graph_image = gr.Image(type="filepath", label="Knowledge Graph")
            # Right column - Chat Interface
            with gr.Column(scale=4, elem_classes="column-panel chat-panel", elem_id="chat-panel"):
                chat_history = gr.Chatbot(
                    [],
                    elem_classes="chat-history",
                    height=400,
                    show_label=False,
                    type="messages",
                    elem_id="chat-history"
                )
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask me about the P&ID...",
                        show_label=False,
                        container=False
                    )
                    chat_button = gr.Button(
                        "Send",
                        elem_classes="primary-button"  # Add custom class
                    )

        def handle_chat(user_message, chat_history, json_path):
            # Chat callback: append the user turn, query the assistant, append
            # its reply (or an apology on failure). Returns ("", history) so
            # the input box is cleared.
            if not user_message:
                return "", chat_history
            # Add user message
            chat_history = chat_history + [{"role": "user", "content": user_message}]
            try:
                # Get assistant response
                # NOTE(review): assumes get_assistant_response returns a plain
                # string here, while handle_user_message also handles a
                # generator — confirm which contract is current.
                response = get_assistant_response(user_message, json_path)
                # Add assistant response
                chat_history = chat_history + [{"role": "assistant", "content": response}]
            except Exception as e:
                logger.error(f"Error in chat response: {str(e)}")
                chat_history = chat_history + [
                    {"role": "assistant", "content": "I apologize, but I encountered an error processing your request."}
                ]
            return "", chat_history

        # Connect UI elements
        chat_input.submit(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])
        chat_button.click(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])
        # process_pnid is a generator: each yield streams its 9-slot outputs
        # list into the components below (order must match).
        process_button.click(
            process_pnid,
            inputs=[file_output],
            outputs=[
                progress_output,  # Progress text (0)
                input_image,  # P&ID (1)
                symbol_image,  # Symbols (2)
                text_image,  # Tags (3)
                line_image,  # Lines (4)
                aggregated_image,  # Aggregated (5)
                graph_image,  # Graph (6)
                chat_history,  # Chat (7)
                json_path  # State (8)
            ],
            show_progress="hidden"  # Hide progress in tabs
        )
    return demo
def main():
    """Entry point: fetch any missing model weights, then launch the UI."""
    # Check for all required models
    required_models = [
        'models/yolo/yolov8n.pt',
        'models/deeplsd/deeplsd_md.tar',
        'models/doctr/craft_mlt_25k.pth',
        'models/doctr/english_g2.pth',
        'models/yolo/intui_LDM_01.pt'
    ]
    missing = [m for m in required_models if not os.path.exists(m)]
    if missing:
        # One bulk download restores everything from Azure Blob Storage.
        download_from_azure()
    demo = create_ui()
    # Local development settings (no HF Spaces conditional).
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,  # changed from the default 7860
        share=True
    )
| if __name__ == "__main__": | |
| main() |