# enhanced_dataset_interface.py
"""
Enhanced Dataset Interface Controller.

Provides the complete interface logic for enhanced dataset mode including
dataset selection, editing, creation, and verification workflows.

Requirements: 2.1, 2.2, 2.7
"""

import gradio as gr
from typing import List, Dict, Tuple, Optional, Any, Union
from datetime import datetime
import uuid

from src.core.verification_models import (
    EnhancedVerificationSession,
    VerificationRecord,
    TestMessage,
    TestDataset,
)
from src.core.enhanced_dataset_manager import EnhancedDatasetManager
from src.core.verification_store import JSONVerificationStore
from src.core.test_datasets import TestDatasetManager
from src.interface.verification_ui import VerificationUIComponents
from src.core.spiritual_monitor import SpiritualMonitor
from src.core.ai_client import AIClientManager
from src.core.enhanced_progress_tracker import EnhancedProgressTracker, VerificationMode
from src.interface.enhanced_progress_components import ProgressTrackingMixin


class EnhancedDatasetInterfaceController(ProgressTrackingMixin):
    """Controller for enhanced dataset mode interface."""

    def __init__(self, store: Optional[JSONVerificationStore] = None):
        """Initialize the enhanced dataset interface controller."""
        super().__init__(VerificationMode.ENHANCED_DATASET)
        self.store = store or JSONVerificationStore()
        self.dataset_manager = EnhancedDatasetManager()
        self.ai_client_manager = AIClientManager()
        self.spiritual_monitor = SpiritualMonitor(self.ai_client_manager)
        self.current_session = None
        self.current_dataset = None
        self.current_message_index = 0
        self.verification_start_time = None

    def initialize_interface(self) -> Tuple[List[str], str, str, List[str]]:
        """
        Initialize the enhanced dataset interface.

        Returns:
            Tuple of (dataset_choices, dataset_info, status_message, templates)
        """
        try:
            # Get all available datasets
            datasets = self.dataset_manager.list_datasets()

            # Create dropdown choices
            dataset_choices = [
                f"{dataset.name} ({dataset.message_count} messages)"
                for dataset in datasets
            ]

            # Get templates for creation
            templates = self.dataset_manager.get_available_templates()

            return (
                dataset_choices,
                "Select a dataset to view details and start verification or editing.",
                "✨ Enhanced Dataset Mode initialized. Select a dataset to get started.",
                templates
            )
        except Exception as e:
            return (
                [],
                f"❌ Error loading datasets: {str(e)}",
                f"❌ Failed to initialize interface: {str(e)}",
                []
            )

    def get_dataset_info(self, dataset_selection: str) -> Tuple[str, Optional[TestDataset]]:
        """
        Get dataset information for display.

        Args:
            dataset_selection: Selected dataset string from dropdown

        Returns:
            Tuple of (dataset_info_markdown, dataset_object)
        """
        try:
            if not dataset_selection:
                return "Select a dataset to view details", None

            # Parse dataset name from selection
            dataset_name = dataset_selection.split(" (")[0]

            # Find matching dataset
            datasets = self.dataset_manager.list_datasets()
            selected_dataset = None
            for dataset in datasets:
                if dataset.name == dataset_name:
                    selected_dataset = dataset
                    break

            if not selected_dataset:
                return "❌ Dataset not found", None

            # Create info display
            info_markdown = f"""### {selected_dataset.name}

**Description:** {selected_dataset.description}

**Message Count:** {selected_dataset.message_count} messages

**Dataset ID:** `{selected_dataset.dataset_id}`

**Classification Breakdown:**
"""

            # Add classification breakdown
            green_count = sum(1 for msg in selected_dataset.messages if msg.pre_classified_label.lower() == "green")
            yellow_count = sum(1 for msg in selected_dataset.messages if msg.pre_classified_label.lower() == "yellow")
            red_count = sum(1 for msg in selected_dataset.messages if msg.pre_classified_label.lower() == "red")

            info_markdown += f"""
- 🟢 GREEN: {green_count} messages
- 🟡 YELLOW: {yellow_count} messages
- 🔴 RED: {red_count} messages
"""

            return info_markdown, selected_dataset

        except Exception as e:
            return f"❌ Error loading dataset info: {str(e)}", None

    def render_test_cases_display(self, dataset: TestDataset) -> str:
        """
        Render test cases for editing display.

        Args:
            dataset: Dataset to display test cases for

        Returns:
            HTML string for test cases display
        """
        if not dataset or not dataset.messages:
            return "<p>No test cases in this dataset.</p>"

        html = """
        <div style="font-family: system-ui; max-height: 400px; overflow-y: auto;">
        """

        for i, message in enumerate(dataset.messages):
            # Get classification badge
            badge_colors = {"green": "🟢", "yellow": "🟡", "red": "🔴"}
            badge = badge_colors.get(message.pre_classified_label.lower(), "⚪")

            # Truncate message text for display
            display_text = message.text[:100] + "..." if len(message.text) > 100 else message.text

            html += f"""
            <div style="margin-bottom: 1em; padding: 1em; background-color: #f9fafb; border-radius: 6px; border: 1px solid #e5e7eb;">
                <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 0.5em;">
                    <h4 style="margin: 0; color: #1f2937;">
                        {badge} Test Case {i+1}
                    </h4>
                    <div>
                        <button onclick="editTestCase('{message.message_id}')"
                                style="background: #3b82f6; color: white; border: none; padding: 0.25em 0.5em; border-radius: 4px; cursor: pointer; margin-right: 0.5em;">
                            ✏️ Edit
                        </button>
                        <button onclick="deleteTestCase('{message.message_id}')"
                                style="background: #dc2626; color: white; border: none; padding: 0.25em 0.5em; border-radius: 4px; cursor: pointer;">
                            🗑️ Delete
                        </button>
                    </div>
                </div>
                <div style="margin-bottom: 0.5em;">
                    <strong>Message:</strong> {display_text}
                </div>
                <div style="font-size: 0.875em; color: #6b7280;">
                    <strong>Expected Classification:</strong> {message.pre_classified_label.upper()}
                </div>
                <div style="font-size: 0.75em; color: #9ca3af; margin-top: 0.5em;">
                    ID: {message.message_id}
                </div>
            </div>
            """

        # NOTE: editTestCase/deleteTestCase below are placeholder handlers; they
        # only log to the browser console and are not wired back into Gradio events.
        html += """
        </div>
        <script>
        function editTestCase(messageId) {
            // This would trigger the edit modal
            console.log('Edit test case:', messageId);
        }
        function deleteTestCase(messageId) {
            if (confirm('Are you sure you want to delete this test case?')) {
                console.log('Delete test case:', messageId);
            }
        }
        </script>
        """

        return html

    def create_new_dataset(
        self,
        name: str,
        description: str,
        template_type: Optional[str] = None
    ) -> Tuple[bool, str, Optional[TestDataset]]:
        """
        Create a new dataset.

        Args:
            name: Dataset name
            description: Dataset description
            template_type: Optional template type

        Returns:
            Tuple of (success, message, dataset)
        """
        try:
            if not name or not name.strip():
                return False, "❌ Dataset name is required", None

            if not description or not description.strip():
                return False, "❌ Dataset description is required", None

            # Create dataset
            if template_type and template_type != "":
                dataset = self.dataset_manager.create_template_dataset(template_type)
                dataset.name = name.strip()
                dataset.description = description.strip()
                self.dataset_manager.update_dataset(dataset.dataset_id, dataset)
            else:
                dataset = self.dataset_manager.create_dataset(name.strip(), description.strip())

            return True, f"✅ Dataset '{name}' created successfully", dataset

        except Exception as e:
            return False, f"❌ Error creating dataset: {str(e)}", None

    def add_test_case(
        self,
        dataset: TestDataset,
        message_text: str,
        classification: str
    ) -> Tuple[bool, str, TestDataset]:
        """
        Add a new test case to the dataset.

        Args:
            dataset: Dataset to add test case to
            message_text: Message text
            classification: Expected classification

        Returns:
            Tuple of (success, message, updated_dataset)
        """
        try:
            if not message_text or not message_text.strip():
                return False, "❌ Message text is required", dataset

            if not classification:
                return False, "❌ Classification is required", dataset

            # Create new test message
            test_message = TestMessage(
                message_id=f"{dataset.dataset_id}_{uuid.uuid4().hex[:8]}",
                text=message_text.strip(),
                pre_classified_label=classification.lower()
            )

            # Add to dataset
            self.dataset_manager.add_test_case(dataset.dataset_id, test_message)

            # Get updated dataset
            updated_dataset = self.dataset_manager.get_dataset(dataset.dataset_id)

            return True, "✅ Test case added successfully", updated_dataset

        except Exception as e:
            return False, f"❌ Error adding test case: {str(e)}", dataset

    def save_dataset(self, dataset: TestDataset) -> Tuple[bool, str]:
        """
        Save dataset changes.

        Args:
            dataset: Dataset to save

        Returns:
            Tuple of (success, message)
        """
        try:
            # Validate dataset
            validation_errors = self.dataset_manager.validate_dataset(dataset)
            if validation_errors:
                error_list = "\n".join([f"• {error}" for error in validation_errors])
                return False, f"❌ Validation errors:\n{error_list}"

            # Save dataset
            self.dataset_manager.update_dataset(dataset.dataset_id, dataset)

            return True, f"✅ Dataset '{dataset.name}' saved successfully"

        except Exception as e:
            return False, f"❌ Error saving dataset: {str(e)}"

    def start_verification_session(
        self,
        dataset: TestDataset,
        verifier_name: str
    ) -> Tuple[bool, str, Optional[EnhancedVerificationSession]]:
        """
        Start a new verification session.

        Args:
            dataset: Dataset to verify
            verifier_name: Name of the verifier

        Returns:
            Tuple of (success, message, session)
        """
        try:
            if not verifier_name or not verifier_name.strip():
                return False, "❌ Verifier name is required", None

            if not dataset or not dataset.messages:
                return False, "❌ Dataset is empty or invalid", None

            # Create enhanced verification session
            session = EnhancedVerificationSession(
                session_id=f"enhanced_{uuid.uuid4().hex}",
                verifier_name=verifier_name.strip(),
                dataset_id=dataset.dataset_id,
                dataset_name=dataset.name,
                mode_type="enhanced_dataset",
                total_messages=len(dataset.messages),
                message_queue=[msg.message_id for msg in dataset.messages],
                mode_metadata={
                    "dataset_version": datetime.now().isoformat(),
                    "original_message_count": len(dataset.messages)
                }
            )

            # Save session
            self.store.save_session(session)
            self.current_session = session
            self.current_dataset = dataset
            self.current_message_index = 0

            # Setup progress tracking
            self.setup_progress_tracking(len(dataset.messages))

            return True, f"✅ Verification session started for '{dataset.name}'", session

        except Exception as e:
            return False, f"❌ Error starting verification: {str(e)}", None

    def get_current_message_for_verification(self) -> Tuple[Optional[TestMessage], Dict[str, Any]]:
        """
        Get the current message for verification.

        Returns:
            Tuple of (test_message, classification_results)
        """
        try:
            if not self.current_session or not self.current_dataset:
                return None, {}

            if self.current_message_index >= len(self.current_dataset.messages):
                return None, {}

            # Get current message
            current_message = self.current_dataset.messages[self.current_message_index]

            # Record verification start time for progress tracking
            self.verification_start_time = datetime.now()

            # Get spiritual distress classification
            assessment = self.spiritual_monitor.classify(current_message.text)

            # Convert to expected format
            classification_result = {
                "decision": assessment.state.value,
                "confidence": assessment.confidence,
                "indicators": assessment.indicators
            }

            return current_message, classification_result

        except Exception as e:
            return None, {"error": str(e)}

    def submit_verification_feedback(
        self,
        is_correct: bool,
        correction: Optional[str] = None,
        notes: str = ""
    ) -> Tuple[bool, str, Dict[str, Any]]:
        """
        Submit verification feedback for current message.

        Args:
            is_correct: Whether the classification is correct
            correction: Correct classification if incorrect
            notes: Optional notes

        Returns:
            Tuple of (success, message, session_stats)
        """
        try:
            if not self.current_session or not self.current_dataset:
                return False, "❌ No active verification session", {}

            current_message = self.current_dataset.messages[self.current_message_index]

            # Preserve the original start time before re-running classification,
            # which would otherwise reset it and zero out the recorded duration
            start_time = self.verification_start_time

            # Get classification result
            _, classification_result = self.get_current_message_for_verification()

            # Create verification record
            # Ensure valid classification values (green, yellow, red only)
            classifier_decision = classification_result.get("decision", "green")
            if classifier_decision not in ["green", "yellow", "red"]:
                classifier_decision = "green"  # Safe fallback

            ground_truth = correction.lower() if correction else current_message.pre_classified_label
            if ground_truth not in ["green", "yellow", "red"]:
                ground_truth = "green"  # Safe fallback

            record = VerificationRecord(
                message_id=current_message.message_id,
                original_message=current_message.text,
                classifier_decision=classifier_decision,
                classifier_confidence=classification_result.get("confidence", 0.0),
                classifier_indicators=classification_result.get("indicators", []),
                ground_truth_label=ground_truth,
                verifier_notes=notes,
                is_correct=is_correct
            )

            # Add to session
            self.current_session.verifications.append(record)
            self.current_session.verified_count += 1
            self.current_session.verified_message_ids.append(current_message.message_id)

            if is_correct:
                self.current_session.correct_count += 1
            else:
                self.current_session.incorrect_count += 1

            # Record verification with timing for progress tracking
            self.record_verification_with_timing(is_correct, start_time)

            # Move to next message
            self.current_message_index += 1
            self.current_session.current_queue_index = self.current_message_index

            # Check if session is complete
            if self.current_message_index >= len(self.current_dataset.messages):
                self.current_session.is_complete = True
                self.current_session.completed_at = datetime.now()

            # Save session
            self.store.save_session(self.current_session)

            # Calculate session stats
            session_stats = {
                "processed": self.current_session.verified_count,
                "total": self.current_session.total_messages,
                "correct": self.current_session.correct_count,
                "incorrect": self.current_session.incorrect_count,
                "accuracy": (self.current_session.correct_count / self.current_session.verified_count * 100) if self.current_session.verified_count > 0 else 0,
                "is_complete": self.current_session.is_complete
            }

            success_msg = "✅ Feedback recorded"
            if self.current_session.is_complete:
                success_msg += f" - Session complete! Final accuracy: {session_stats['accuracy']:.1f}%"

            return True, success_msg, session_stats

        except Exception as e:
            return False, f"❌ Error submitting feedback: {str(e)}", {}

    def export_session_results(self, format_type: str) -> Tuple[bool, str, Optional[str]]:
        """
        Export session results in specified format.

        Args:
            format_type: Export format ("csv", "json", "xlsx")

        Returns:
            Tuple of (success, message, file_path)
        """
        try:
            if not self.current_session:
                return False, "❌ No active session to export", None

            if format_type == "csv":
                file_content = self.store.export_to_csv(self.current_session.session_id)
                file_path = f"session_{self.current_session.session_id}.csv"
            elif format_type == "json":
                file_content = self.store.export_to_json(self.current_session.session_id)
                file_path = f"session_{self.current_session.session_id}.json"
            elif format_type == "xlsx":
                file_content = self.store.export_to_xlsx(self.current_session.session_id)
                file_path = f"session_{self.current_session.session_id}.xlsx"
            else:
                return False, f"❌ Unsupported export format: {format_type}", None

            return True, f"✅ Results exported to {format_type.upper()}", file_path

        except Exception as e:
            return False, f"❌ Error exporting results: {str(e)}", None

    def get_enhanced_progress_info(self) -> Dict[str, Any]:
        """
        Get enhanced progress information for display.

        Returns:
            Dictionary containing progress information
        """
        if not hasattr(self, 'progress_tracker') or not self.progress_tracker:
            return {
                "progress_display": "📊 Progress: Ready to start",
                "accuracy_display": "🎯 Current Accuracy: No verifications yet",
                "time_display": "⏱️ Time: Not started",
                "error_display": "",
                "stats_summary": "No active session"
            }

        return {
            "progress_display": self.progress_tracker.get_progress_display(),
            "accuracy_display": self.progress_tracker.get_accuracy_display(),
            "time_display": self.progress_tracker.get_time_tracking_display(),
            "error_display": self.progress_tracker.get_error_display(),
            "stats_summary": self._get_session_stats_summary()
        }

    def record_verification_error(self, error_message: str, can_continue: bool = True) -> None:
        """
        Record a verification error.

        Args:
            error_message: Description of the error
            can_continue: Whether processing can continue
        """
        if hasattr(self, 'progress_tracker') and self.progress_tracker:
            self.progress_tracker.record_error(error_message, can_continue)

    def pause_verification_session(self) -> Tuple[bool, bool, bool]:
        """
        Pause the current verification session.

        Returns:
            Tuple of control button visibility states
        """
        if hasattr(self, 'progress_tracker') and self.progress_tracker:
            return self.handle_session_pause()
        return False, False, True

    def resume_verification_session(self) -> Tuple[bool, bool, bool]:
        """
        Resume the current verification session.

        Returns:
            Tuple of control button visibility states
        """
        if hasattr(self, 'progress_tracker') and self.progress_tracker:
            return self.handle_session_resume()
        return True, False, True

    def _get_session_stats_summary(self) -> str:
        """Get formatted session statistics summary."""
        if not self.current_session:
            return "No active session"

        accuracy = (self.current_session.correct_count / self.current_session.verified_count * 100) if self.current_session.verified_count > 0 else 0

        return f"""
**Session Progress:**
- Dataset: {self.current_session.dataset_name}
- Processed: {self.current_session.verified_count}/{self.current_session.total_messages}
- Accuracy: {accuracy:.1f}%
- Correct: {self.current_session.correct_count}
- Incorrect: {self.current_session.incorrect_count}
"""