Spaces:
Sleeping
Sleeping
| # file_upload_interface.py | |
| """File Upload Interface for Enhanced Verification Modes. | |
| Provides interface for uploading CSV files, validating content, | |
| batch processing with progress tracking, and exporting results. | |
| Requirements: 4.1, 4.3, 4.4, 4.5, 4.6, 4.7, 12.1, 12.2, 12.3, 12.4, 12.5 | |
| """ | |
| import gradio as gr | |
| import tempfile | |
| import os | |
| import uuid | |
| from typing import List, Dict, Tuple, Optional, Any | |
| from datetime import datetime | |
| from src.core.file_processing_service import FileProcessingService | |
| from src.core.verification_models import ( | |
| EnhancedVerificationSession, | |
| VerificationRecord, | |
| TestMessage, | |
| FileUploadResult, | |
| ) | |
| from src.core.verification_store import JSONVerificationStore | |
| from src.core.ai_client import AIClientManager | |
| from src.config.prompts import SYSTEM_PROMPT_ENTRY_CLASSIFIER | |
| from src.core.enhanced_progress_tracker import EnhancedProgressTracker, VerificationMode | |
| from src.interface.enhanced_progress_components import ProgressTrackingMixin | |
| from src.interface.ui_consistency_components import ( | |
| StandardizedComponents, | |
| ClassificationDisplay, | |
| ProgressDisplay, | |
| ErrorDisplay, | |
| SessionDisplay, | |
| HelpDisplay | |
| ) | |
| class FileUploadInterfaceController(ProgressTrackingMixin): | |
| """Controller for file upload mode interface.""" | |
| def __init__(self): | |
| """Initialize the file upload interface controller.""" | |
| super().__init__(VerificationMode.FILE_UPLOAD) | |
| self.file_processor = FileProcessingService() | |
| self.store = JSONVerificationStore() | |
| self.ai_client = AIClientManager() | |
| # Optional per-session model overrides (UI Model Settings tab) | |
| self.model_overrides = {} | |
| self.ai_client.set_model_overrides(self.model_overrides) | |
| # Optional per-session prompt overrides (UI Edit Prompts tab) | |
| self.prompt_overrides = {} | |
| self.ai_client.set_prompt_overrides(self.prompt_overrides) | |
| self.current_session = None | |
| self.current_file_result = None | |
| self.current_message_index = 0 | |
| self.batch_processing_start_time = None | |
| def set_model_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None: | |
| """Set per-session model overrides from the UI.""" | |
| self.model_overrides = dict(overrides or {}) | |
| self.ai_client.set_model_overrides(self.model_overrides) | |
| def set_prompt_overrides(self, overrides: Optional[Dict[str, str]] = None) -> None: | |
| """Set per-session prompt overrides from the UI.""" | |
| self.prompt_overrides = dict(overrides or {}) | |
| self.ai_client.set_prompt_overrides(self.prompt_overrides) | |
| def process_uploaded_file(self, file_path: str) -> Tuple[bool, str, Optional[FileUploadResult], str]: | |
| """ | |
| Process an uploaded file and return validation results. | |
| Args: | |
| file_path: Path to the uploaded file | |
| Returns: | |
| Tuple of (success, status_message, file_result, preview_html) | |
| """ | |
| if not file_path or not file_path.endswith('.csv'): | |
| return False, "❌ No file uploaded", None, "" | |
| try: | |
| # Process the file | |
| file_result = self.file_processor.process_uploaded_file(file_path) | |
| if file_result.validation_errors: | |
| # File has validation errors | |
| error_details = self.file_processor.get_validation_error_details(file_result.validation_errors) | |
| error_html = self._format_validation_errors(error_details) | |
| status_msg = f"❌ File validation failed ({len(file_result.validation_errors)} errors)" | |
| return False, status_msg, file_result, error_html | |
| else: | |
| # File is valid - generate preview | |
| preview_html = self._generate_file_preview(file_result) | |
| status_msg = f"✅ File processed successfully: {file_result.valid_rows} valid test cases found" | |
| return True, status_msg, file_result, preview_html | |
| except Exception as e: | |
| error_msg = f"❌ Error processing file: {str(e)}" | |
| return False, error_msg, None, "" | |
| def _format_validation_errors(self, error_details: Dict[str, Any]) -> str: | |
| """ | |
| Format validation errors as HTML using standardized components. | |
| Args: | |
| error_details: Error details from file processor | |
| Returns: | |
| HTML string with formatted errors | |
| """ | |
| # Create main error message | |
| main_message = f"File validation failed ({error_details['total_errors']} errors)" | |
| # Prepare suggestions list | |
| suggestions = [] | |
| # Add first 10 errors as suggestions | |
| errors_to_show = error_details['errors'][:10] | |
| suggestions.extend(errors_to_show) | |
| if len(error_details['errors']) > 10: | |
| remaining = len(error_details['errors']) - 10 | |
| suggestions.append(f"... and {remaining} more errors") | |
| # Add format suggestions | |
| if error_details.get('suggestions'): | |
| suggestions.extend(error_details['suggestions']) | |
| # Add format help | |
| format_help = error_details.get('format_help', {}) | |
| if format_help: | |
| suggestions.extend([ | |
| f"Required columns: {', '.join(format_help.get('required_columns', []))}", | |
| f"Valid classifications: {', '.join(format_help.get('valid_classifications', []))}", | |
| "Supported delimiters (CSV): comma, semicolon, tab" | |
| ]) | |
| return ErrorDisplay.create_error_html_display( | |
| main_message, | |
| "error", | |
| suggestions | |
| ) | |
| def _generate_file_preview(self, file_result: FileUploadResult) -> str: | |
| """ | |
| Generate HTML preview of successfully processed file. | |
| Args: | |
| file_result: File processing result | |
| Returns: | |
| HTML string with file preview | |
| """ | |
| html = f""" | |
| <div style="font-family: system-ui; padding: 1em; background-color: #f0fdf4; border-left: 4px solid #16a34a; border-radius: 4px;"> | |
| <h4 style="color: #16a34a; margin-top: 0;">✅ File Preview: {file_result.original_filename}</h4> | |
| <div style="margin-bottom: 1em;"> | |
| <strong>File Statistics:</strong><br> | |
| • Format: {file_result.file_format.upper()}<br> | |
| • Total rows: {file_result.total_rows}<br> | |
| • Valid test cases: {file_result.valid_rows}<br> | |
| • Upload time: {file_result.upload_timestamp.strftime('%Y-%m-%d %H:%M:%S')} | |
| </div> | |
| """ | |
| if file_result.parsed_test_cases: | |
| html += """ | |
| <div style="margin-bottom: 1em;"> | |
| <strong>Sample Test Cases (first 5):</strong> | |
| </div> | |
| <div style="background-color: white; border-radius: 4px; padding: 0.5em; border: 1px solid #d1d5db;"> | |
| <table style="width: 100%; border-collapse: collapse;"> | |
| <thead> | |
| <tr style="background-color: #f9fafb;"> | |
| <th style="padding: 0.5em; text-align: left; border-bottom: 1px solid #e5e7eb;">#</th> | |
| <th style="padding: 0.5em; text-align: left; border-bottom: 1px solid #e5e7eb;">Message Preview</th> | |
| <th style="padding: 0.5em; text-align: left; border-bottom: 1px solid #e5e7eb;">Expected Classification</th> | |
| </tr> | |
| </thead> | |
| <tbody> | |
| """ | |
| # Show first 5 test cases | |
| for i, test_case in enumerate(file_result.parsed_test_cases[:5], 1): | |
| message_preview = test_case.text[:80] + "..." if len(test_case.text) > 80 else test_case.text | |
| classification_badge = self._get_classification_badge(test_case.pre_classified_label) | |
| html += f""" | |
| <tr> | |
| <td style="padding: 0.5em; border-bottom: 1px solid #f3f4f6;">{i}</td> | |
| <td style="padding: 0.5em; border-bottom: 1px solid #f3f4f6;">{message_preview}</td> | |
| <td style="padding: 0.5em; border-bottom: 1px solid #f3f4f6;">{classification_badge}</td> | |
| </tr> | |
| """ | |
| html += """ | |
| </tbody> | |
| </table> | |
| </div> | |
| """ | |
| html += """ | |
| <div style="margin-top: 1em; padding: 0.75em; background-color: #ecfdf5; border-radius: 4px; border: 1px solid #a7f3d0;"> | |
| <p style="margin: 0; color: #065f46;"> | |
| <strong>✅ Ready for batch processing!</strong><br> | |
| Click "Start Batch Processing" to begin verification of all test cases. | |
| </p> | |
| </div> | |
| </div> | |
| """ | |
| return html | |
| def _get_classification_badge(self, classification: str) -> str: | |
| """ | |
| Get HTML badge for classification using standardized components. | |
| Args: | |
| classification: Classification label | |
| Returns: | |
| HTML badge string | |
| """ | |
| return ClassificationDisplay.format_classification_html_badge(classification) | |
| def start_batch_processing(self, verifier_name: str, file_result: FileUploadResult) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]: | |
| """ | |
| Start batch processing session. | |
| Args: | |
| verifier_name: Name of the verifier | |
| file_result: File processing result | |
| Returns: | |
| Tuple of (success, message, session) | |
| """ | |
| if not verifier_name.strip(): | |
| return False, "❌ Please enter your name to start verification", None | |
| if not file_result or not file_result.parsed_test_cases: | |
| return False, "❌ No valid test cases to process", None | |
| try: | |
| # Create enhanced verification session | |
| session_id = uuid.uuid4().hex | |
| session = EnhancedVerificationSession( | |
| session_id=session_id, | |
| verifier_name=verifier_name.strip(), | |
| dataset_id=file_result.file_id, | |
| dataset_name=f"File Upload: {file_result.original_filename}", | |
| mode_type="file_upload", | |
| mode_metadata={ | |
| "file_id": file_result.file_id, | |
| "original_filename": file_result.original_filename, | |
| "file_format": file_result.file_format, | |
| "total_file_rows": file_result.total_rows, | |
| "valid_file_rows": file_result.valid_rows, | |
| }, | |
| file_source=file_result.original_filename, | |
| total_messages=len(file_result.parsed_test_cases), | |
| message_queue=[tc.message_id for tc in file_result.parsed_test_cases], | |
| current_queue_index=0, | |
| ) | |
| # Save session | |
| self.store.save_session(session) | |
| # Set current session and file result | |
| self.current_session = session | |
| self.current_file_result = file_result | |
| self.current_message_index = 0 | |
| # Setup progress tracking for batch processing | |
| self.setup_progress_tracking(len(file_result.parsed_test_cases)) | |
| return True, f"✅ Batch processing started for {len(file_result.parsed_test_cases)} test cases", session | |
| except Exception as e: | |
| return False, f"❌ Error starting batch processing: {str(e)}", None | |
| def get_current_message_for_batch_processing(self) -> Tuple[Optional[TestMessage], Optional[Dict[str, Any]]]: | |
| """ | |
| Get current message for batch processing. | |
| Returns: | |
| Tuple of (test_message, classification_result) | |
| """ | |
| if not self.current_session or not self.current_file_result: | |
| return None, None | |
| if self.current_message_index >= len(self.current_file_result.parsed_test_cases): | |
| return None, None | |
| # Get current test message | |
| test_message = self.current_file_result.parsed_test_cases[self.current_message_index] | |
| try: | |
| # Record batch processing start time for progress tracking | |
| self.batch_processing_start_time = datetime.now() | |
| # Call AI classifier using the same approach as manual input | |
| user_prompt = f"Please analyze this patient message for spiritual distress:\n\n{test_message.text}" | |
| response = self.ai_client.call_entry_classifier_api( | |
| system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER, | |
| user_prompt=user_prompt, | |
| temperature=0.3, | |
| ) | |
| # Parse the response to extract classification details | |
| classification_result = self._parse_classification_response(response) | |
| return test_message, classification_result | |
| except Exception as e: | |
| # Return error result | |
| error_result = { | |
| "decision": "error", | |
| "confidence": 0.0, | |
| "indicators": [f"Classification error: {str(e)}"], | |
| "error": str(e) | |
| } | |
| return test_message, error_result | |
| def _parse_classification_response(self, response: str) -> Dict[str, Any]: | |
| """ | |
| Parse AI response to extract classification details. | |
| Args: | |
| response: Raw AI response | |
| Returns: | |
| Dictionary with classification details | |
| """ | |
| # Default classification structure | |
| classification = { | |
| "decision": "unknown", | |
| "confidence": 0.0, | |
| "indicators": [], | |
| "raw_response": response | |
| } | |
| # Simple parsing logic - look for key indicators in response | |
| response_lower = response.lower() | |
| # Determine decision based on keywords | |
| if "red" in response_lower or "severe" in response_lower or "high risk" in response_lower: | |
| classification["decision"] = "red" | |
| classification["confidence"] = 0.8 | |
| elif "yellow" in response_lower or "moderate" in response_lower or "potential" in response_lower: | |
| classification["decision"] = "yellow" | |
| classification["confidence"] = 0.7 | |
| elif "green" in response_lower or "low" in response_lower or "no distress" in response_lower: | |
| classification["decision"] = "green" | |
| classification["confidence"] = 0.9 | |
| # Extract indicators (simple keyword matching) | |
| indicators = [] | |
| indicator_keywords = [ | |
| "hopelessness", "despair", "meaninglessness", "isolation", | |
| "anger at god", "spiritual pain", "guilt", "shame", | |
| "questioning faith", "loss of purpose", "existential crisis" | |
| ] | |
| for keyword in indicator_keywords: | |
| if keyword in response_lower: | |
| indicators.append(keyword.title()) | |
| if not indicators: | |
| indicators = ["General spiritual assessment"] | |
| classification["indicators"] = indicators | |
| return classification | |
| def run_batch_classification(self, progress: Optional[gr.Progress] = None) -> Tuple[bool, str, Dict[str, Any]]: | |
| """Run classification for the whole uploaded dataset and persist results. | |
| File Upload Mode is already labeled (ground truth provided in the file), so we | |
| don't need interactive message-by-message verification. Instead, we: | |
| - classify every message | |
| - store the model output as reasoning in `verifier_notes` | |
| - mark each record as correct/incorrect by comparing to ground truth | |
| """ | |
| if not self.current_session or not self.current_file_result: | |
| return False, "❌ No active session", {} | |
| total = len(self.current_file_result.parsed_test_cases) | |
| if total == 0: | |
| return False, "❌ No messages to process", {} | |
| try: | |
| # Reset any prior run state | |
| self.current_session.verifications = [] | |
| self.current_session.verified_count = 0 | |
| self.current_session.correct_count = 0 | |
| self.current_session.incorrect_count = 0 | |
| self.current_session.verified_message_ids = [] | |
| self.setup_progress_tracking(total) | |
| for idx, test_message in enumerate(self.current_file_result.parsed_test_cases): | |
| if progress is not None: | |
| progress( | |
| (idx) / total, | |
| desc=f"Processing {idx + 1}/{total}" | |
| ) | |
| self.batch_processing_start_time = datetime.now() | |
| user_prompt = ( | |
| "Please analyze this patient message for spiritual distress:\n\n" | |
| f"{test_message.text}" | |
| ) | |
| raw_response = self.ai_client.call_entry_classifier_api( | |
| system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER, | |
| user_prompt=user_prompt, | |
| temperature=0.3, | |
| model_override=self.model_overrides.get("EntryClassifier"), | |
| ) | |
| classification_result = self._parse_classification_response(raw_response) | |
| classifier_decision = classification_result.get("decision", "green") | |
| if classifier_decision not in ["green", "yellow", "red"]: | |
| classifier_decision = "green" | |
| ground_truth = test_message.pre_classified_label | |
| if ground_truth not in ["green", "yellow", "red"]: | |
| ground_truth = "green" | |
| is_correct = classifier_decision == ground_truth | |
| verification_record = VerificationRecord( | |
| message_id=test_message.message_id, | |
| original_message=test_message.text, | |
| classifier_decision=classifier_decision, | |
| classifier_confidence=classification_result.get("confidence", 0.0), | |
| classifier_indicators=classification_result.get("indicators", []), | |
| ground_truth_label=ground_truth, | |
| verifier_notes=raw_response, # store full LLM output as reasoning | |
| is_correct=is_correct, | |
| ) | |
| self.current_session.verifications.append(verification_record) | |
| self.current_session.verified_count += 1 | |
| self.current_session.verified_message_ids.append(test_message.message_id) | |
| if is_correct: | |
| self.current_session.correct_count += 1 | |
| else: | |
| self.current_session.incorrect_count += 1 | |
| self.record_verification_with_timing(is_correct, self.batch_processing_start_time) | |
| self.current_session.current_queue_index = idx + 1 | |
| self.current_session.is_complete = True | |
| self.current_session.completed_at = datetime.now() | |
| if progress is not None: | |
| progress(1.0, desc=f"Completed {total}/{total}") | |
| self.store.save_session(self.current_session) | |
| accuracy = ( | |
| (self.current_session.correct_count / self.current_session.verified_count * 100) | |
| if self.current_session.verified_count | |
| else 0 | |
| ) | |
| stats = { | |
| "processed": self.current_session.verified_count, | |
| "total": total, | |
| "correct": self.current_session.correct_count, | |
| "incorrect": self.current_session.incorrect_count, | |
| "accuracy": accuracy, | |
| "is_complete": True, | |
| } | |
| return True, f"✅ Batch classification completed. Accuracy: {accuracy:.1f}%", stats | |
| except Exception as e: | |
| return False, f"❌ Error during batch classification: {str(e)}", {} | |
| def export_batch_results_with_reasoning(self, format_type: str) -> Tuple[bool, str, Optional[str]]: | |
| """Export results including LLM reasoning. | |
| We rely on `verifier_notes` field to carry reasoning (raw model output). | |
| """ | |
| return self.export_batch_results(format_type) | |
| def submit_batch_verification(self, is_correct: bool, correction: Optional[str] = None, notes: str = "") -> Tuple[bool, str, Dict[str, Any]]: | |
| """ | |
| Submit verification for current message in batch processing. | |
| Args: | |
| is_correct: Whether the classification is correct | |
| correction: Correct classification if incorrect | |
| notes: Additional notes | |
| Returns: | |
| Tuple of (success, message, session_stats) | |
| """ | |
| if not self.current_session or not self.current_file_result: | |
| return False, "❌ No active batch processing session", {} | |
| if self.current_message_index >= len(self.current_file_result.parsed_test_cases): | |
| return False, "❌ No more messages to process", {} | |
| try: | |
| # Get current test message and classification | |
| test_message = self.current_file_result.parsed_test_cases[self.current_message_index] | |
| current_message, classification_result = self.get_current_message_for_batch_processing() | |
| if not current_message or not classification_result: | |
| return False, "❌ Error getting current message", {} | |
| # Create verification record | |
| # Ensure valid classification values (green, yellow, red only) | |
| classifier_decision = classification_result.get("decision", "green") | |
| if classifier_decision not in ["green", "yellow", "red"]: | |
| classifier_decision = "green" # Safe fallback | |
| ground_truth = correction if correction else test_message.pre_classified_label | |
| if ground_truth not in ["green", "yellow", "red"]: | |
| ground_truth = "green" # Safe fallback | |
| verification_record = VerificationRecord( | |
| message_id=test_message.message_id, | |
| original_message=test_message.text, | |
| classifier_decision=classifier_decision, | |
| classifier_confidence=classification_result.get("confidence", 0.0), | |
| classifier_indicators=classification_result.get("indicators", []), | |
| ground_truth_label=ground_truth, | |
| verifier_notes=notes, | |
| is_correct=is_correct, | |
| ) | |
| # Add to session | |
| self.current_session.verifications.append(verification_record) | |
| self.current_session.verified_count += 1 | |
| self.current_session.verified_message_ids.append(test_message.message_id) | |
| if is_correct: | |
| self.current_session.correct_count += 1 | |
| else: | |
| self.current_session.incorrect_count += 1 | |
| # Record verification with timing for progress tracking | |
| self.record_verification_with_timing(is_correct, self.batch_processing_start_time) | |
| # Move to next message | |
| self.current_message_index += 1 | |
| self.current_session.current_queue_index = self.current_message_index | |
| # Check if session is complete | |
| if self.current_message_index >= len(self.current_file_result.parsed_test_cases): | |
| self.current_session.is_complete = True | |
| self.current_session.completed_at = datetime.now() | |
| # Save session | |
| self.store.save_session(self.current_session) | |
| # Calculate stats | |
| stats = { | |
| "processed": self.current_session.verified_count, | |
| "total": self.current_session.total_messages, | |
| "correct": self.current_session.correct_count, | |
| "incorrect": self.current_session.incorrect_count, | |
| "accuracy": (self.current_session.correct_count / self.current_session.verified_count * 100) if self.current_session.verified_count > 0 else 0, | |
| "is_complete": self.current_session.is_complete, | |
| } | |
| if self.current_session.is_complete: | |
| message = f"✅ Batch processing completed! Final accuracy: {stats['accuracy']:.1f}%" | |
| else: | |
| message = f"✅ Verification recorded. Progress: {stats['processed']}/{stats['total']}" | |
| return True, message, stats | |
| except Exception as e: | |
| return False, f"❌ Error submitting verification: {str(e)}", {} | |
| def export_batch_results(self, format_type: str) -> Tuple[bool, str, Optional[str]]: | |
| """ | |
| Export batch processing results. | |
| Args: | |
| format_type: Export format ("csv", "json") | |
| Returns: | |
| Tuple of (success, message, file_path) | |
| """ | |
| if not self.current_session: | |
| return False, "❌ No active session to export", None | |
| try: | |
| if format_type == "csv": | |
| content = self.store.export_to_csv(self.current_session.session_id) | |
| # Save to temporary file | |
| temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) | |
| temp_file.write(content) | |
| temp_file.close() | |
| file_path = temp_file.name | |
| elif format_type == "json": | |
| content = self.store.export_to_json(self.current_session.session_id) | |
| # Save to temporary file | |
| temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) | |
| temp_file.write(content) | |
| temp_file.close() | |
| file_path = temp_file.name | |
| else: | |
| return False, f"❌ Unsupported export format: {format_type}", None | |
| if file_path: | |
| return True, f"✅ Results exported to {format_type.upper()} format", file_path | |
| else: | |
| return False, f"❌ Failed to export results in {format_type.upper()} format", None | |
| except Exception as e: | |
| return False, f"❌ Export error: {str(e)}", None | |
| def get_enhanced_progress_info(self) -> Dict[str, Any]: | |
| """ | |
| Get enhanced progress information for display. | |
| Returns: | |
| Dictionary containing progress information | |
| """ | |
| if not hasattr(self, 'progress_tracker') or not self.progress_tracker: | |
| return { | |
| "progress_display": "📊 Progress: Ready to start", | |
| "accuracy_display": "🎯 Current Accuracy: No verifications yet", | |
| "speed_display": "⚡ Processing Speed: Calculating...", | |
| "time_display": "⏱️ Time: Not started", | |
| "error_display": "", | |
| "stats_summary": "No active session" | |
| } | |
| return { | |
| "progress_display": self.progress_tracker.get_progress_display(), | |
| "accuracy_display": self.progress_tracker.get_accuracy_display(), | |
| "speed_display": self.progress_tracker.get_processing_speed_display(), | |
| "time_display": self.progress_tracker.get_time_tracking_display(), | |
| "error_display": self.progress_tracker.get_error_display(), | |
| "stats_summary": self._get_session_stats_summary() | |
| } | |
| def record_batch_processing_error(self, error_message: str, can_continue: bool = True) -> None: | |
| """ | |
| Record a batch processing error. | |
| Args: | |
| error_message: Description of the error | |
| can_continue: Whether processing can continue | |
| """ | |
| if hasattr(self, 'progress_tracker') and self.progress_tracker: | |
| self.progress_tracker.record_error(error_message, can_continue) | |
| def pause_batch_processing(self) -> Tuple[bool, bool, bool]: | |
| """ | |
| Pause the current batch processing session. | |
| Returns: | |
| Tuple of control button visibility states | |
| """ | |
| if hasattr(self, 'progress_tracker') and self.progress_tracker: | |
| return self.handle_session_pause() | |
| return False, False, True | |
| def resume_batch_processing(self) -> Tuple[bool, bool, bool]: | |
| """ | |
| Resume the current batch processing session. | |
| Returns: | |
| Tuple of control button visibility states | |
| """ | |
| if hasattr(self, 'progress_tracker') and self.progress_tracker: | |
| return self.handle_session_resume() | |
| return True, False, True | |
| def _get_session_stats_summary(self) -> str: | |
| """Get formatted session statistics summary.""" | |
| if not self.current_session: | |
| return "No active session" | |
| accuracy = (self.current_session.correct_count / self.current_session.verified_count * 100) if self.current_session.verified_count > 0 else 0 | |
| return f""" | |
| **Batch Processing Session:** | |
| - File: {self.current_session.file_source or 'Unknown'} | |
| - Processed: {self.current_session.verified_count}/{self.current_session.total_messages} | |
| - Accuracy: {accuracy:.1f}% | |
| - Correct: {self.current_session.correct_count} | |
| - Incorrect: {self.current_session.incorrect_count} | |
| - Processing Speed: {self.progress_tracker.get_processing_speed_display() if hasattr(self, 'progress_tracker') else 'Unknown'} | |
| """ | |
| def get_template_files(self) -> Tuple[str, Optional[bytes]]: | |
| """ | |
| Get template files for download. | |
| Returns: | |
| Tuple of (csv_content, xlsx_bytes) | |
| """ | |
| csv_content = self.file_processor.generate_csv_template() | |
| xlsx_bytes = None # Removed XLSX template generation | |
| return csv_content, xlsx_bytes | |
| def create_file_upload_interface(model_overrides_state: Optional[gr.State] = None) -> gr.Blocks: | |
| """ | |
| Create the complete file upload mode interface. | |
| Returns: | |
| Gradio Blocks component for file upload mode | |
| """ | |
| controller = FileUploadInterfaceController() | |
| # Apply any provided model overrides at build time. | |
| # Note: this is safe even if the state is mutated later, because the click | |
| # handlers also refresh overrides before calls. | |
| if model_overrides_state is not None: | |
| try: | |
| controller.set_model_overrides(model_overrides_state.value or {}) | |
| except Exception: | |
| # Don't fail UI creation if state isn't initialized yet | |
| pass | |
| with gr.Blocks() as file_upload_interface: | |
| # Headers and back button are in parent interface | |
| # Application state | |
| current_file_result_state = gr.State(value=None) | |
| current_session_state = gr.State(value=None) | |
| # File upload section | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| gr.Markdown("## 📤 Upload Test File") | |
| file_upload = gr.File( | |
| label="Select CSV File", | |
| file_types=[".csv"], | |
| type="filepath" | |
| ) | |
| with gr.Row(): | |
| process_file_btn = StandardizedComponents.create_primary_button("Process File", "🔍") | |
| process_file_btn.scale = 2 | |
| clear_file_btn = StandardizedComponents.create_secondary_button("Clear", "🗑️") | |
| clear_file_btn.scale = 1 | |
| with gr.Column(scale=1): | |
| gr.Markdown("## 📋 Template Files") | |
| gr.Markdown("Download template files to see the required format:") | |
| with gr.Column(): | |
| # Use DownloadButton for direct file download | |
| download_csv_template_btn = gr.DownloadButton( | |
| "📄 Download CSV Template", | |
| value="exports/template_test_messages.csv", | |
| size="sm" | |
| ) | |
| # XLSX template removed (CSV-only workflow) | |
| gr.Markdown("### 📝 Format Requirements") | |
| gr.Markdown(""" | |
| **Required columns:** | |
| - `message` (or `text`): Patient message text | |
| - `expected_classification` (or `classification`): Expected result | |
| **Valid classifications:** | |
| - `green`: No distress | |
| - `yellow`: Potential distress | |
| - `red`: Severe distress | |
| **Supported formats:** | |
| - CSV with comma, semicolon, or tab delimiters | |
| """) | |
| # File processing results section | |
| file_results_section = gr.Row(visible=False) | |
| with file_results_section: | |
| with gr.Column(): | |
| gr.Markdown("## 📊 File Processing Results") | |
| file_preview_display = gr.HTML( | |
| value="", | |
| label="File Preview" | |
| ) | |
| # Batch processing section | |
| batch_processing_section = gr.Row(visible=False) | |
| with batch_processing_section: | |
| with gr.Column(): | |
| gr.Markdown("## 🚀 Batch Processing") | |
| # Processing controls | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| verifier_name_input = gr.Textbox( | |
| label="Verifier Name", | |
| placeholder="Enter your name...", | |
| interactive=True | |
| ) | |
| with gr.Column(scale=1): | |
| start_batch_btn = StandardizedComponents.create_primary_button( | |
| "Start Batch Processing", | |
| "🚀", | |
| "lg" | |
| ) | |
| # Visual progress bar (updates during batch classification) | |
| batch_progress_bar = gr.Progress() | |
| # Progress text display | |
| batch_progress_display = gr.Markdown( | |
| "Ready to start batch processing", | |
| label="Progress" | |
| ) | |
| # Export results (visible after batch completes) | |
| gr.Markdown("### 💾 Download Results") | |
| with gr.Row(): | |
| export_csv_btn = gr.DownloadButton( | |
| label="Download CSV", | |
| variant="secondary", | |
| visible=False, | |
| ) | |
| export_json_btn = gr.DownloadButton( | |
| label="Download JSON", | |
| variant="secondary", | |
| visible=False, | |
| ) | |
| # Message processing section (initially hidden) | |
| message_processing_section = gr.Row(visible=False) | |
| with message_processing_section: | |
| with gr.Column(scale=2): | |
| # Current message display | |
| current_message_display = gr.Textbox( | |
| label="📝 Current Message", | |
| interactive=False, | |
| lines=4 | |
| ) | |
| # Expected vs Actual comparison | |
| with gr.Row(): | |
| with gr.Column(): | |
| expected_classification_display = gr.Markdown( | |
| "Expected: Loading...", | |
| label="📋 Expected Classification" | |
| ) | |
| with gr.Column(): | |
| actual_classification_display = gr.Markdown( | |
| "Actual: Loading...", | |
| label="🎯 AI Classification" | |
| ) | |
| # Classification details | |
| classifier_confidence_display = gr.Markdown( | |
| "Confidence: Loading...", | |
| label="📊 Confidence Level" | |
| ) | |
| classifier_indicators_display = gr.Markdown( | |
| "Indicators: Loading...", | |
| label="🔍 Detected Indicators" | |
| ) | |
| # Verification buttons | |
| with gr.Row(): | |
| correct_classification_btn = StandardizedComponents.create_primary_button("Correct", "✓") | |
| correct_classification_btn.scale = 1 | |
| incorrect_classification_btn = StandardizedComponents.create_stop_button("Incorrect", "✗") | |
| incorrect_classification_btn.scale = 1 | |
| # Correction section (initially hidden) | |
| correction_section = gr.Row(visible=False) | |
| with correction_section: | |
| correction_selector = ClassificationDisplay.create_classification_radio() | |
| correction_notes = gr.Textbox( | |
| label="Notes (Optional)", | |
| placeholder="Why is this incorrect?", | |
| lines=2, | |
| interactive=True | |
| ) | |
| submit_correction_btn = StandardizedComponents.create_primary_button("Submit", "✓") | |
| with gr.Column(scale=1): | |
| # Batch statistics | |
| gr.Markdown("### 📊 Batch Statistics") | |
| batch_stats_display = gr.Markdown( | |
| """ | |
| **Messages Processed:** 0 | |
| **Correct Classifications:** 0 | |
| **Incorrect Classifications:** 0 | |
| **Accuracy:** 0% | |
| **Processing Speed:** 0 msg/min | |
| """, | |
| label="Statistics" | |
| ) | |
| gr.Markdown("### 💾 Export Results") | |
| gr.Markdown("Download buttons appear in the 'Batch Processing' section after completion.") | |
| # Status messages | |
| status_message = gr.Markdown("", visible=True) | |
| # Event handlers | |
| def on_process_file(file_path): | |
| """Handle file processing.""" | |
| if not file_path: | |
| return ( | |
| gr.Row(visible=False), # file_results_section | |
| gr.Row(visible=False), # batch_processing_section | |
| "", # file_preview_display | |
| None, # current_file_result_state | |
| "❌ Please select a file to upload" # status_message | |
| ) | |
| success, status_msg, file_result, preview_html = controller.process_uploaded_file(file_path) | |
| if success: | |
| return ( | |
| gr.Row(visible=True), # file_results_section | |
| gr.Row(visible=True), # batch_processing_section | |
| preview_html, # file_preview_display | |
| file_result, # current_file_result_state | |
| status_msg # status_message | |
| ) | |
| else: | |
| return ( | |
| gr.Row(visible=True), # file_results_section | |
| gr.Row(visible=False), # batch_processing_section | |
| preview_html, # file_preview_display | |
| file_result, # current_file_result_state | |
| status_msg # status_message | |
| ) | |
| def on_clear_file(): | |
| """Handle file clearing.""" | |
| return ( | |
| gr.Row(visible=False), # file_results_section | |
| gr.Row(visible=False), # batch_processing_section | |
| gr.Row(visible=False), # message_processing_section | |
| "", # file_preview_display | |
| None, # current_file_result_state | |
| None, # current_session_state | |
| gr.DownloadButton(visible=False), # export_csv_btn | |
| gr.DownloadButton(visible=False), # export_json_btn | |
| "File cleared" # status_message | |
| ) | |
| def on_start_batch_processing(verifier_name, file_result): | |
| """Handle starting batch processing.""" | |
| if not file_result: | |
| return ( | |
| gr.Row(visible=False), # message_processing_section | |
| None, # current_session_state | |
| "❌ No file processed" # status_message | |
| ) | |
| success, message, session = controller.start_batch_processing(verifier_name, file_result) | |
| if success: | |
| # Simplified behavior: dataset is already labeled, so run full batch | |
| # classification immediately and generate results for export. | |
| run_ok, run_msg, stats = controller.run_batch_classification(progress=batch_progress_bar) | |
| if run_ok: | |
| progress_text = f"✅ Completed: {stats.get('processed', 0)}/{stats.get('total', 0)} messages" | |
| return ( | |
| gr.Row(visible=False), # message_processing_section (not used in simplified flow) | |
| session, # current_session_state | |
| "", # current_message_display | |
| "", # expected_classification_display | |
| "", # actual_classification_display | |
| "", # classifier_confidence_display | |
| "", # classifier_indicators_display | |
| progress_text, # batch_progress_display | |
| gr.DownloadButton(visible=True), # export_csv_btn | |
| gr.DownloadButton(visible=True), # export_json_btn | |
| run_msg # status_message | |
| ) | |
| return ( | |
| gr.Row(visible=False), # message_processing_section | |
| session, # current_session_state | |
| "", # current_message_display | |
| "", # expected_classification_display | |
| "", # actual_classification_display | |
| "", # classifier_confidence_display | |
| "", # classifier_indicators_display | |
| "❌ Batch classification failed", # batch_progress_display | |
| gr.DownloadButton(visible=False), # export_csv_btn | |
| gr.DownloadButton(visible=False), # export_json_btn | |
| run_msg # status_message | |
| ) | |
| else: | |
| return ( | |
| gr.Row(visible=False), # message_processing_section | |
| None, # current_session_state | |
| "", # current_message_display | |
| "", # expected_classification_display | |
| "", # actual_classification_display | |
| "", # classifier_confidence_display | |
| "", # classifier_indicators_display | |
| "", # batch_progress_display | |
| gr.DownloadButton(visible=False), # export_csv_btn | |
| gr.DownloadButton(visible=False), # export_json_btn | |
| message # status_message | |
| ) | |
| def on_correct_classification(): | |
| """Handle correct classification feedback.""" | |
| success, message, stats = controller.submit_batch_verification(True) | |
| if success and not stats.get('is_complete', False): | |
| # Load next message | |
| current_message, classification_result = controller.get_current_message_for_batch_processing() | |
| if current_message: | |
| expected_badge = controller._get_classification_badge(current_message.pre_classified_label) | |
| actual_badge = controller._get_classification_badge(classification_result.get('decision', 'unknown')) | |
| confidence_text = f"📊 {classification_result.get('confidence', 0) * 100:.1f}% confident" | |
| indicators_text = "🔍 " + ", ".join(classification_result.get('indicators', ['No indicators'])) | |
| progress_text = f"Progress: {stats['processed'] + 1} of {stats['total']} messages" | |
| stats_text = f""" | |
| **Messages Processed:** {stats['processed']} | |
| **Correct Classifications:** {stats['correct']} | |
| **Incorrect Classifications:** {stats['incorrect']} | |
| **Accuracy:** {stats['accuracy']:.1f}% | |
| **Processing Speed:** {stats['processed']} msg/min | |
| """ | |
| return ( | |
| current_message.text, # current_message_display | |
| f"Expected: {expected_badge}", # expected_classification_display | |
| f"AI Result: {actual_badge}", # actual_classification_display | |
| confidence_text, # classifier_confidence_display | |
| indicators_text, # classifier_indicators_display | |
| progress_text, # batch_progress_display | |
| stats_text, # batch_stats_display | |
| gr.Row(visible=False), # correction_section | |
| gr.DownloadButton(visible=True), # export_csv_btn | |
| gr.DownloadButton(visible=True), # export_json_btn | |
| message # status_message | |
| ) | |
| else: | |
| # Batch complete | |
| stats_text = f""" | |
| **Batch Complete!** | |
| **Messages Processed:** {stats['processed']} | |
| **Correct Classifications:** {stats['correct']} | |
| **Incorrect Classifications:** {stats['incorrect']} | |
| **Final Accuracy:** {stats['accuracy']:.1f}% | |
| """ | |
| return ( | |
| "Batch processing completed!", # current_message_display | |
| "✅ All messages processed", # expected_classification_display | |
| "", # actual_classification_display | |
| "", # classifier_confidence_display | |
| "", # classifier_indicators_display | |
| "✅ Batch processing complete", # batch_progress_display | |
| stats_text, # batch_stats_display | |
| gr.Row(visible=False), # correction_section | |
| gr.DownloadButton(visible=True), # export_csv_btn | |
| gr.DownloadButton(visible=True), # export_json_btn | |
| message # status_message | |
| ) | |
| else: | |
| return ( | |
| gr.Textbox(value=""), # current_message_display (no change) | |
| gr.Markdown(value=""), # expected_classification_display (no change) | |
| gr.Markdown(value=""), # actual_classification_display (no change) | |
| gr.Markdown(value=""), # classifier_confidence_display (no change) | |
| gr.Markdown(value=""), # classifier_indicators_display (no change) | |
| gr.Markdown(value=""), # batch_progress_display (no change) | |
| gr.Markdown(value=""), # batch_stats_display (no change) | |
| gr.Row(visible=False), # correction_section | |
| gr.DownloadButton(visible=False), # export_csv_btn | |
| gr.DownloadButton(visible=False), # export_json_btn | |
| message # status_message | |
| ) | |
| def on_incorrect_classification(): | |
| """Handle incorrect classification - show correction options.""" | |
| return ( | |
| gr.Row(visible=True), # correction_section | |
| "Please select the correct classification" # status_message | |
| ) | |
| def on_submit_correction(correction, notes): | |
| """Handle correction submission.""" | |
| success, message, stats = controller.submit_batch_verification( | |
| False, correction, notes | |
| ) | |
| if success and not stats.get('is_complete', False): | |
| # Load next message | |
| current_message, classification_result = controller.get_current_message_for_batch_processing() | |
| if current_message: | |
| expected_badge = controller._get_classification_badge(current_message.pre_classified_label) | |
| actual_badge = controller._get_classification_badge(classification_result.get('decision', 'unknown')) | |
| confidence_text = f"📊 {classification_result.get('confidence', 0) * 100:.1f}% confident" | |
| indicators_text = "🔍 " + ", ".join(classification_result.get('indicators', ['No indicators'])) | |
| progress_text = f"Progress: {stats['processed'] + 1} of {stats['total']} messages" | |
| stats_text = f""" | |
| **Messages Processed:** {stats['processed']} | |
| **Correct Classifications:** {stats['correct']} | |
| **Incorrect Classifications:** {stats['incorrect']} | |
| **Accuracy:** {stats['accuracy']:.1f}% | |
| **Processing Speed:** {stats['processed']} msg/min | |
| """ | |
| return ( | |
| current_message.text, # current_message_display | |
| f"Expected: {expected_badge}", # expected_classification_display | |
| f"AI Result: {actual_badge}", # actual_classification_display | |
| confidence_text, # classifier_confidence_display | |
| indicators_text, # classifier_indicators_display | |
| progress_text, # batch_progress_display | |
| stats_text, # batch_stats_display | |
| gr.Row(visible=False), # correction_section | |
| "", # correction_notes (clear) | |
| gr.DownloadButton(visible=True), # export_csv_btn | |
| gr.DownloadButton(visible=True), # export_json_btn | |
| message # status_message | |
| ) | |
| else: | |
| # Batch complete | |
| stats_text = f""" | |
| **Batch Complete!** | |
| **Messages Processed:** {stats['processed']} | |
| **Correct Classifications:** {stats['correct']} | |
| **Incorrect Classifications:** {stats['incorrect']} | |
| **Final Accuracy:** {stats['accuracy']:.1f}% | |
| """ | |
| return ( | |
| "Batch processing completed!", # current_message_display | |
| "✅ All messages processed", # expected_classification_display | |
| "", # actual_classification_display | |
| "", # classifier_confidence_display | |
| "", # classifier_indicators_display | |
| "✅ Batch processing complete", # batch_progress_display | |
| stats_text, # batch_stats_display | |
| gr.Row(visible=False), # correction_section | |
| "", # correction_notes (clear) | |
| gr.DownloadButton(visible=True), # export_csv_btn | |
| gr.DownloadButton(visible=True), # export_json_btn | |
| message # status_message | |
| ) | |
| else: | |
| return ( | |
| gr.Textbox(value=""), # current_message_display (no change) | |
| gr.Markdown(value=""), # expected_classification_display (no change) | |
| gr.Markdown(value=""), # actual_classification_display (no change) | |
| gr.Markdown(value=""), # classifier_confidence_display (no change) | |
| gr.Markdown(value=""), # classifier_indicators_display (no change) | |
| gr.Markdown(value=""), # batch_progress_display (no change) | |
| gr.Markdown(value=""), # batch_stats_display (no change) | |
| gr.Row(visible=True), # correction_section (keep visible) | |
| notes, # correction_notes (keep) | |
| gr.DownloadButton(visible=False), # export_csv_btn | |
| gr.DownloadButton(visible=False), # export_json_btn | |
| message # status_message | |
| ) | |
| def on_export_results_file(format_type): | |
| """Handle results export and return the generated file for download.""" | |
| success, message, file_path = controller.export_batch_results(format_type) | |
| if success and file_path: | |
| return file_path | |
| return None | |
| def on_download_csv_template(): | |
| """Handle CSV template download.""" | |
| csv_content, _ = controller.get_template_files() | |
| # Create temporary file | |
| temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) | |
| temp_file.write(csv_content) | |
| temp_file.close() | |
| return temp_file.name | |
| # Bind event handlers | |
| process_file_btn.click( | |
| on_process_file, | |
| inputs=[file_upload], | |
| outputs=[ | |
| file_results_section, | |
| batch_processing_section, | |
| file_preview_display, | |
| current_file_result_state, | |
| status_message | |
| ] | |
| ) | |
| clear_file_btn.click( | |
| on_clear_file, | |
| outputs=[ | |
| file_results_section, | |
| batch_processing_section, | |
| message_processing_section, | |
| file_preview_display, | |
| current_file_result_state, | |
| current_session_state, | |
| export_csv_btn, | |
| export_json_btn, | |
| status_message | |
| ] | |
| ) | |
| start_batch_btn.click( | |
| on_start_batch_processing, | |
| inputs=[verifier_name_input, current_file_result_state], | |
| outputs=[ | |
| message_processing_section, | |
| current_session_state, | |
| current_message_display, | |
| expected_classification_display, | |
| actual_classification_display, | |
| classifier_confidence_display, | |
| classifier_indicators_display, | |
| batch_progress_display, | |
| export_csv_btn, | |
| export_json_btn, | |
| status_message | |
| ] | |
| ) | |
| correct_classification_btn.click( | |
| on_correct_classification, | |
| outputs=[ | |
| current_message_display, | |
| expected_classification_display, | |
| actual_classification_display, | |
| classifier_confidence_display, | |
| classifier_indicators_display, | |
| batch_progress_display, | |
| batch_stats_display, | |
| correction_section, | |
| export_csv_btn, | |
| export_json_btn, | |
| status_message | |
| ] | |
| ) | |
| incorrect_classification_btn.click( | |
| on_incorrect_classification, | |
| outputs=[correction_section, status_message] | |
| ) | |
| submit_correction_btn.click( | |
| on_submit_correction, | |
| inputs=[correction_selector, correction_notes], | |
| outputs=[ | |
| current_message_display, | |
| expected_classification_display, | |
| actual_classification_display, | |
| classifier_confidence_display, | |
| classifier_indicators_display, | |
| batch_progress_display, | |
| batch_stats_display, | |
| correction_section, | |
| correction_notes, | |
| export_csv_btn, | |
| export_json_btn, | |
| status_message | |
| ] | |
| ) | |
| export_csv_btn.click(lambda: on_export_results_file("csv"), outputs=[export_csv_btn]) | |
| export_json_btn.click(lambda: on_export_results_file("json"), outputs=[export_json_btn]) | |
| download_csv_template_btn.click( | |
| on_download_csv_template, | |
| outputs=[gr.File(visible=False)] | |
| ) | |
| return file_upload_interface |