| """ | |
| 🚀 Universal Prompt Optimizer - Enhanced Production UI v8.0 | |
| Principal Engineer Edition: Linear/Vercel-style Dark Mode with Premium UX | |
| """ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| # Add src directory to Python path for Hugging Face Spaces | |
| # This ensures gepa_optimizer can be imported even if -e . installation fails | |
| src_path = Path(__file__).parent / "src" | |
| if src_path.exists() and str(src_path) not in sys.path: | |
| sys.path.insert(0, str(src_path)) | |
| import gradio as gr | |
| import json | |
| import base64 | |
| import io | |
| import os | |
| import logging | |
| import traceback | |
| import html | |
| import numpy as np | |
| from PIL import Image as PILImage | |
| from typing import List, Dict, Optional, Any, Tuple | |
| import threading | |
| from collections import deque | |
| # Optional import for URL image downloads | |
| try: | |
| import requests | |
| REQUESTS_AVAILABLE = True | |
| except ImportError: | |
| REQUESTS_AVAILABLE = False | |
# ==========================================
# 0. LOGGING & BACKEND UTILS
# ==========================================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Global Candidates Store (Thread-safe)
_candidates_store = {
    'candidates': deque(maxlen=100),
    'lock': threading.Lock(),
    'iteration': 0
}

def add_candidate_to_store(candidate: Dict[str, Any]):
    with _candidates_store['lock']:
        _candidates_store['candidates'].append({
            'iteration': _candidates_store['iteration'],
            'source': candidate.get('source', 'unknown'),
            'prompt': candidate.get('prompt', ''),
            'timestamp': candidate.get('timestamp', ''),
            'index': len(_candidates_store['candidates']) + 1
        })

def get_candidates_from_store() -> List[Dict[str, Any]]:
    with _candidates_store['lock']:
        return list(_candidates_store['candidates'])

def clear_candidates_store():
    with _candidates_store['lock']:
        _candidates_store['candidates'].clear()
        _candidates_store['iteration'] = 0

def increment_iteration():
    with _candidates_store['lock']:
        _candidates_store['iteration'] += 1
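
# Example usage of the candidate store (a minimal sketch; values are illustrative):
#
#   clear_candidates_store()
#   increment_iteration()
#   add_candidate_to_store({"source": "mutation", "prompt": "Summarize the input...", "timestamp": "now"})
#   get_candidates_from_store()
#   # -> [{'iteration': 1, 'source': 'mutation', 'prompt': 'Summarize the input...',
#   #      'timestamp': 'now', 'index': 1}]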

# ==========================================
# 1. MOCK BACKEND (Kept as provided)
# ==========================================
try:
    from gepa_optimizer import quick_optimize_sync, OptimizedResult
    BACKEND_AVAILABLE = True
    logger.info("✅ Successfully imported gepa_optimizer")
except ImportError as e:
    BACKEND_AVAILABLE = False
    logger.error(f"❌ Failed to import gepa_optimizer: {str(e)}")
    logger.error(f"Python path: {sys.path}")
    logger.error(f"Current directory: {os.getcwd()}")
    logger.error(f"src directory exists: {os.path.exists(os.path.join(os.path.dirname(__file__), 'src'))}")
    from dataclasses import dataclass

    @dataclass
    class OptimizedResult:
        optimized_prompt: str
        improvement_metrics: dict
        iteration_history: list

    def quick_optimize_sync(seed_prompt, dataset, model, **kwargs):
        import time
        iterations = kwargs.get('max_iterations', 5)
        batch_size = kwargs.get('batch_size', 4)
        use_llego = kwargs.get('use_llego', True)
        # Simulate processing time based on iterations
        time.sleep(0.5 * iterations)
        llego_note = "with LLEGO crossover" if use_llego else "standard mutation only"
        return OptimizedResult(
            optimized_prompt=f"""# OPTIMIZED PROMPT FOR {model}
# ----------------------------------------
# Optimization: {iterations} iterations, batch size {batch_size}, {llego_note}

## Task Context
{seed_prompt}

## Refined Instructions
1. Analyse the input constraints strictly.
2. Verify output format against expected schema.
3. Apply chain-of-thought reasoning before answering.
4. Cross-reference with provided examples for consistency.

## Safety & Edge Cases
- If input is ambiguous, ask for clarification.
- Maintain a professional, neutral tone.
- Handle edge cases gracefully with informative responses.""",
            improvement_metrics={
                "baseline_score": 0.45,
                "final_score": 0.92,
                "improvement": "+104.4%",
                "iterations_run": iterations,
                "candidates_evaluated": iterations * batch_size,
            },
            iteration_history=[
                "Iter 1: Baseline evaluation - Score: 0.45",
                "Iter 2: Added Chain-of-Thought constraints - Score: 0.62",
                "Iter 3: Refined output formatting rules - Score: 0.78",
                f"Iter 4: {'LLEGO crossover applied' if use_llego else 'Mutation applied'} - Score: 0.88",
                "Iter 5: Final refinement - Score: 0.92",
            ][:iterations],
        )
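
# Example of the mock fallback path (illustrative; only runs when gepa_optimizer
# is not installed — the real backend's call signature matches, but its result
# object differs, as handled in safe_optimize below):
#
#   res = quick_optimize_sync("You are a helpful assistant.",
#                             [{"input": "hi", "output": "hello"}],
#                             "openai/gpt-4o", max_iterations=2, batch_size=3)
#   res.improvement_metrics["candidates_evaluated"]   # -> 6 (2 iterations x batch of 3)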

# ==========================================
# 2. HELPER FUNCTIONS
# ==========================================
def gradio_image_to_base64(image_input) -> Optional[str]:
    """Convert Gradio image input to base64 string with comprehensive error handling."""
    if image_input is None:
        return None
    try:
        pil_image = None
        if isinstance(image_input, np.ndarray):
            try:
                # Validate array shape and dtype
                if image_input.size == 0:
                    logger.warning("Empty image array provided")
                    return None
                pil_image = PILImage.fromarray(image_input)
            except (ValueError, TypeError) as e:
                logger.error(f"Failed to convert numpy array to PIL Image: {str(e)}")
                return None
        elif isinstance(image_input, PILImage.Image):
            pil_image = image_input
        elif isinstance(image_input, str):
            if not os.path.exists(image_input):
                logger.warning(f"Image file not found: {image_input}")
                return None
            try:
                pil_image = PILImage.open(image_input)
            except (IOError, OSError) as e:
                logger.error(f"Failed to open image file: {str(e)}")
                return None
        else:
            logger.warning(f"Unsupported image input type: {type(image_input)}")
            return None
        if pil_image is None:
            return None
        # Convert image to RGB mode if necessary (some formats like RGBA, P, etc. need conversion)
        try:
            # Convert to RGB if image has transparency or is in a mode that might cause issues
            if pil_image.mode in ('RGBA', 'LA', 'P'):
                # Create a white background for transparent images
                rgb_image = PILImage.new('RGB', pil_image.size, (255, 255, 255))
                if pil_image.mode == 'P':
                    pil_image = pil_image.convert('RGBA')
                rgb_image.paste(pil_image, mask=pil_image.split()[-1] if pil_image.mode in ('RGBA', 'LA') else None)
                pil_image = rgb_image
            elif pil_image.mode != 'RGB':
                # Convert other modes to RGB
                pil_image = pil_image.convert('RGB')
        except Exception as convert_error:
            logger.warning(f"Image mode conversion failed, trying to continue: {str(convert_error)}")
            # Try to convert anyway
            try:
                pil_image = pil_image.convert('RGB')
            except Exception:
                pass
        try:
            buffered = io.BytesIO()
            # Save as PNG (universal format) - PIL will handle conversion from any format
            # PNG supports all color modes and is widely compatible
            pil_image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            return f"data:image/png;base64,{img_str}"
        except (IOError, OSError, ValueError) as e:
            logger.error(f"Failed to encode image to base64: {str(e)}")
            return None
    except Exception as e:
        logger.error(f"Unexpected error in image conversion: {str(e)}\n{traceback.format_exc()}")
        return None
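
# Example (a minimal sketch with an illustrative input; not part of the app flow):
#
#   arr = np.zeros((4, 4, 3), dtype=np.uint8)   # tiny black RGB image
#   data_uri = gradio_image_to_base64(arr)
#   assert data_uri is not None and data_uri.startswith("data:image/png;base64,")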

def validate_dataset(dataset: List[Dict]) -> Tuple[bool, str]:
    """Validate dataset structure and content with detailed error messages."""
    if not isinstance(dataset, list):
        return False, "Dataset must be a list of examples."
    if len(dataset) == 0:
        return False, "Dataset is empty. Add at least one example."
    # Validate each item in the dataset
    for i, item in enumerate(dataset):
        if not isinstance(item, dict):
            return False, f"Dataset item {i+1} must be a dictionary with 'input' and 'output' keys."
        if "input" not in item or "output" not in item:
            return False, f"Dataset item {i+1} is missing required 'input' or 'output' field."
        if not isinstance(item.get("input"), str) or not isinstance(item.get("output"), str):
            return False, f"Dataset item {i+1} has invalid 'input' or 'output' type (must be strings)."
        if not item.get("input", "").strip() or not item.get("output", "").strip():
            return False, f"Dataset item {i+1} has empty 'input' or 'output' field."
    return True, ""
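
# Example validations (illustrative values):
#
#   validate_dataset([{"input": "2+2?", "output": "4"}])   # -> (True, "")
#   validate_dataset([{"input": "2+2?"}])                  # -> (False, "Dataset item 1 is missing ...")
#   validate_dataset([])                                   # -> (False, "Dataset is empty. ...")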

def validate_model(model: str, custom_model: str) -> Tuple[bool, str]:
    """Validate model selection and custom model format."""
    if not model:
        return False, "Please select a foundation model."
    if model == "custom":
        if not custom_model or not custom_model.strip():
            return False, "Custom model selected but no model ID provided."
        # Validate custom model format (provider/model_name)
        parts = custom_model.strip().split("/")
        if len(parts) != 2:
            return False, "Custom model ID must be in format 'provider/model_name' (e.g., 'openai/gpt-4')."
        if not parts[0].strip() or not parts[1].strip():
            return False, "Custom model ID provider and model name cannot be empty."
    return True, ""
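
# Example validations (illustrative values):
#
#   validate_model("openai/gpt-4o", "")        # -> (True, "")
#   validate_model("custom", "openai/gpt-4")   # -> (True, "")
#   validate_model("custom", "gpt-4")          # -> (False, "Custom model ID must be in format ...")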

def validate_api_keys(model: str, api_keys: Dict[str, str]) -> Tuple[bool, str]:
    """Validate that required API keys are provided for the selected model."""
    if not api_keys:
        return True, ""  # Keys are optional if already set in environment
    # Normalise to lowercase so "OpenAI/gpt-4" matches the same provider as "openai/gpt-4"
    model_provider = model.split("/")[0].lower() if "/" in model else model.lower()
    # Check if model requires a specific provider key
    required_providers = {
        "openai": "openai",
        "anthropic": "anthropic",
        "google": "google"
    }
    if model_provider in required_providers:
        provider = required_providers[model_provider]
        key_value = api_keys.get(provider, "").strip() if api_keys.get(provider) else ""
        # Check environment variable as fallback
        env_vars = {
            "openai": "OPENAI_API_KEY",
            "anthropic": "ANTHROPIC_API_KEY",
            "google": "GOOGLE_API_KEY"
        }
        if not key_value and not os.environ.get(env_vars.get(provider, "")):
            return False, f"API key for {provider.capitalize()} is required for model '{model}' but not provided."
    return True, ""
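
# Example (illustrative; assumes OPENAI_API_KEY is not set in the environment):
#
#   validate_api_keys("openai/gpt-4o", {"openai": "sk-..."})   # -> (True, "")
#   validate_api_keys("openai/gpt-4o", {"openai": ""})         # -> (False, "API key for Openai is required ...")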

def safe_optimize(seed_prompt, dataset, model, custom_model="", max_iterations=5, max_metric_calls=50, batch_size=4, use_llego=True, api_keys=None):
    """Safely run optimization with comprehensive error handling."""
    try:
        # Log backend status
        if not BACKEND_AVAILABLE:
            logger.warning("⚠️ Backend not available - using mock optimizer. Check gepa_optimizer installation.")
        else:
            logger.info("✅ Backend available - using real gepa_optimizer")
        # Validate seed prompt
        if not seed_prompt or not isinstance(seed_prompt, str):
            return False, "Seed prompt is required and must be a string.", None
        if not seed_prompt.strip():
            return False, "Seed prompt cannot be empty.", None
        # Validate dataset
        is_valid, msg = validate_dataset(dataset)
        if not is_valid:
            return False, msg, None
        # Determine final model
        final_model = custom_model.strip() if custom_model and custom_model.strip() else model
        # Validate model
        model_valid, model_msg = validate_model(model, custom_model)
        if not model_valid:
            return False, model_msg, None
        # Validate API keys
        api_valid, api_msg = validate_api_keys(final_model, api_keys or {})
        if not api_valid:
            return False, api_msg, None
        # Validate optimization parameters
        if not isinstance(max_iterations, int) or max_iterations < 1 or max_iterations > 50:
            return False, "Max iterations must be between 1 and 50.", None
        if not isinstance(max_metric_calls, int) or max_metric_calls < 10 or max_metric_calls > 500:
            return False, "Max metric calls must be between 10 and 500.", None
        if not isinstance(batch_size, int) or batch_size < 1 or batch_size > 20:
            return False, "Batch size must be between 1 and 20.", None
        # Set API keys from UI if provided
        if api_keys:
            try:
                key_mapping = {
                    "openai": "OPENAI_API_KEY",
                    "google": "GOOGLE_API_KEY",
                    "anthropic": "ANTHROPIC_API_KEY",
                }
                for provider, env_var in key_mapping.items():
                    if api_keys.get(provider) and api_keys[provider].strip():
                        os.environ[env_var] = api_keys[provider].strip()
                        logger.info(f"Set {provider} API key from UI")
            except Exception as e:
                logger.error(f"Failed to set API keys: {str(e)}")
                return False, f"Failed to configure API keys: {str(e)}", None
        # Run optimization
        try:
            # Check GEPA version for debugging
            if BACKEND_AVAILABLE:
                try:
                    import gepa
                    logger.info(f"📦 GEPA library version: {getattr(gepa, '__version__', 'unknown')}")
                except Exception as e:
                    logger.warning(f"Could not check GEPA version: {e}")
            logger.info(f"🚀 Starting optimization with model: {final_model}")
            logger.info(f"   Parameters: iterations={max_iterations}, metric_calls={max_metric_calls}, batch={batch_size}, llego={use_llego}")
            logger.info(f"   Dataset size: {len(dataset)} examples")
            logger.info("   🔍 GEPA should call: evaluate(capture_traces=True) → make_reflective_dataset() → propose_new_texts()")
            result = quick_optimize_sync(
                seed_prompt=seed_prompt,
                dataset=dataset,
                model=final_model,
                max_iterations=max_iterations,
                max_metric_calls=max_metric_calls,
                batch_size=batch_size,
                use_llego=use_llego,
                verbose=True,
            )
            # Log result details for debugging
            logger.info("📊 Optimization result received:")
            logger.info(f"   Type: {type(result)}")
            logger.info(f"   Has prompt: {hasattr(result, 'prompt')}")
            logger.info(f"   Has optimized_prompt: {hasattr(result, 'optimized_prompt')}")
            if hasattr(result, 'improvement_data'):
                logger.info(f"   improvement_data: {result.improvement_data}")
            if hasattr(result, 'total_iterations'):
                logger.info(f"   total_iterations: {result.total_iterations}")
            if hasattr(result, 'optimization_time'):
                logger.info(f"   optimization_time: {result.optimization_time}")
            if hasattr(result, 'status'):
                logger.info(f"   status: {result.status}")
            if hasattr(result, 'error_message') and result.error_message:
                logger.error(f"   error_message: {result.error_message}")
            # Validate result structure
            if not result:
                return False, "Optimization returned no result.", None
            # Check for both property-based (real backend) and attribute-based (mock backend)
            has_prompt = False
            try:
                # Real backend uses .prompt property
                if hasattr(result, 'prompt'):
                    _ = result.prompt  # Try to access property
                    has_prompt = True
                # Mock backend uses .optimized_prompt attribute
                elif hasattr(result, 'optimized_prompt'):
                    has_prompt = True
            except Exception as e:
                logger.warning(f"Error checking result structure: {str(e)}")
            if not has_prompt:
                return False, "Optimization result is missing required prompt field.", None
            return True, "Success", result
        except KeyboardInterrupt:
            logger.warning("Optimization interrupted by user")
            return False, "Optimization was interrupted.", None
        except TimeoutError:
            logger.error("Optimization timed out")
            return False, "Optimization timed out. Try reducing max_iterations or max_metric_calls.", None
        except ConnectionError as e:
            logger.error(f"Connection error during optimization: {str(e)}")
            return False, f"Connection error: {str(e)}. Check your internet connection and API keys.", None
        except ValueError as e:
            logger.error(f"Invalid parameter in optimization: {str(e)}")
            return False, f"Invalid configuration: {str(e)}", None
        except Exception as e:
            error_msg = str(e)
            logger.error(f"Optimization failed: {error_msg}\n{traceback.format_exc()}")
            # Provide user-friendly error messages
            if "api" in error_msg.lower() or "key" in error_msg.lower():
                return False, f"API error: {error_msg}. Please check your API keys.", None
            elif "rate limit" in error_msg.lower():
                return False, "Rate limit exceeded. Please wait a moment and try again.", None
            elif "quota" in error_msg.lower():
                return False, "API quota exceeded. Please check your account limits.", None
            else:
                return False, f"Optimization failed: {error_msg}", None
    except Exception as e:
        logger.error(f"Unexpected error in safe_optimize: {str(e)}\n{traceback.format_exc()}")
        return False, f"Unexpected error: {str(e)}", None
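
# safe_optimize returns a (success, message, result) triple, e.g. (illustrative):
#
#   ok, msg, result = safe_optimize("Classify the sentiment of the input.",
#                                   [{"input": "great!", "output": "positive"}],
#                                   "openai/gpt-4o")
#   if not ok:
#       print(msg)   # a validation or runtime error, e.g. "Dataset is empty. ..."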

# ==========================================
# 3. UI LOGIC
# ==========================================
def add_example(input_text, output_text, image_input, current_dataset):
    """Add an example to the dataset with comprehensive error handling."""
    try:
        # Validate inputs
        if not input_text:
            raise gr.Error("Input text is required.")
        if not output_text:
            raise gr.Error("Output text is required.")
        if not isinstance(input_text, str) or not isinstance(output_text, str):
            raise gr.Error("Input and Output must be text strings.")
        input_text = input_text.strip()
        output_text = output_text.strip()
        if not input_text:
            raise gr.Error("Input text cannot be empty.")
        if not output_text:
            raise gr.Error("Output text cannot be empty.")
        # Validate dataset state
        if not isinstance(current_dataset, list):
            raise gr.Error("Dataset state is invalid. Please refresh the page.")
        # Process image with error handling
        img_b64 = None
        try:
            img_b64 = gradio_image_to_base64(image_input)
        except Exception as e:
            logger.warning(f"Image processing failed, continuing without image: {str(e)}")
            # Continue without image - it's optional
        # Create new item
        try:
            new_item = {
                "input": input_text,
                "output": output_text,
                "image": img_b64,
                "image_preview": "🖼️ Image" if img_b64 else "-"
            }
            # Validate item structure
            if not isinstance(new_item["input"], str) or not isinstance(new_item["output"], str):
                raise gr.Error("Failed to create dataset item: invalid data types.")
            current_dataset.append(new_item)
            return current_dataset, "", "", None
        except gr.Error:
            raise
        except Exception as e:
            logger.error(f"Failed to add example to dataset: {str(e)}")
            raise gr.Error(f"Failed to add example: {str(e)}")
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except Exception as e:
        logger.error(f"Unexpected error in add_example: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Unexpected error: {str(e)}")

def update_table(dataset):
    """Update the dataset table display with error handling."""
    try:
        if not dataset:
            return []
        if not isinstance(dataset, list):
            logger.error(f"Invalid dataset type: {type(dataset)}")
            return []
        table_data = []
        for i, item in enumerate(dataset):
            try:
                if not isinstance(item, dict):
                    logger.warning(f"Skipping invalid dataset item {i+1}: not a dictionary")
                    continue
                input_text = str(item.get("input", ""))[:50] if item.get("input") else ""
                output_text = str(item.get("output", ""))[:50] if item.get("output") else ""
                image_preview = str(item.get("image_preview", "-"))
                table_data.append([i+1, input_text, output_text, image_preview])
            except Exception as e:
                logger.warning(f"Error processing dataset item {i+1}: {str(e)}")
                continue
        return table_data
    except Exception as e:
        logger.error(f"Error updating table: {str(e)}\n{traceback.format_exc()}")
        return []

def clear_dataset():
    """Clear the dataset and the table display."""
    return [], []

def get_candidates_display():
    """Generate HTML display for candidates with error handling."""
    try:
        candidates = get_candidates_from_store()
        if not candidates:
            return "<div style='padding: 2rem; text-align: center; color: #6b7280;'><div style='font-size: 3rem; opacity: 0.3; margin-bottom: 1rem;'>🧬</div><p>Waiting for optimization to start...</p></div>"
        if not isinstance(candidates, list):
            logger.error(f"Invalid candidates type: {type(candidates)}")
            return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates.</div>"
        html_output = "<div style='display: flex; flex-direction: column; gap: 12px;'>"
        # Show last 10 candidates
        candidates_to_show = list(candidates)[-10:]
        for c in reversed(candidates_to_show):
            try:
                if not isinstance(c, dict):
                    continue
                iteration = str(c.get('iteration', '?'))
                source = str(c.get('source', 'unknown')).upper()
                prompt = str(c.get('prompt', ''))[:200]
                # Escape HTML to prevent XSS
                iteration = html.escape(iteration)
                source = html.escape(source)
                prompt = html.escape(prompt)
                html_output += f"""
                <div style='background: linear-gradient(135deg, #0f172a 0%, #1e293b 100%); border: 1px solid #334155; border-radius: 8px; padding: 16px; position: relative; overflow: hidden;'>
                    <div style='position: absolute; top: 0; left: 0; width: 100%; height: 2px; background: linear-gradient(90deg, #06b6d4, #3b82f6);'></div>
                    <div style='display: flex; justify-content: space-between; align-items: center; margin-bottom: 8px;'>
                        <span style='font-family: "JetBrains Mono", monospace; font-size: 0.75rem; color: #06b6d4; font-weight: 600;'>ITERATION {iteration}</span>
                        <span style='background: #1e293b; border: 1px solid #334155; padding: 2px 8px; border-radius: 4px; font-size: 0.7rem; color: #94a3b8;'>{source}</span>
                    </div>
                    <div style='font-family: "JetBrains Mono", monospace; font-size: 0.85rem; color: #cbd5e1; line-height: 1.6;'>{prompt}...</div>
                </div>
                """
            except Exception as e:
                logger.warning(f"Error rendering candidate: {str(e)}")
                continue
        html_output += "</div>"
        return html_output
    except Exception as e:
        logger.error(f"Error generating candidates display: {str(e)}\n{traceback.format_exc()}")
        return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates display.</div>"

def run_optimization_flow(seed, dataset, model, custom_model, iter_count, call_count, batch, llego, k_openai, k_google, k_anthropic, progress=gr.Progress()):
    """Run the optimization flow with comprehensive error handling."""
    import time
    try:
        # Validate inputs
        if not seed:
            raise gr.Error("Seed prompt is required.")
        if not dataset:
            raise gr.Error("Dataset is required. Add at least one example.")
        if not model:
            raise gr.Error("Model selection is required.")
        # Validate numeric parameters
        try:
            iter_count = int(iter_count) if iter_count else 5
            call_count = int(call_count) if call_count else 50
            batch = int(batch) if batch else 4
        except (ValueError, TypeError) as e:
            raise gr.Error(f"Invalid optimization parameters: {str(e)}")
        # Determine final model
        try:
            final_model = custom_model.strip() if custom_model and custom_model.strip() else model
        except Exception as e:
            logger.warning(f"Error processing custom model: {str(e)}")
            final_model = model
        # Clear candidates store
        try:
            clear_candidates_store()
        except Exception as e:
            logger.warning(f"Error clearing candidates store: {str(e)}")
        # Prepare API keys
        api_keys = {}
        try:
            api_keys = {
                "openai": k_openai if k_openai else "",
                "google": k_google if k_google else "",
                "anthropic": k_anthropic if k_anthropic else ""
            }
        except Exception as e:
            logger.warning(f"Error processing API keys: {str(e)}")
        # Initial state
        try:
            yield (
                gr.update(visible=True),
                gr.update(visible=False),
                gr.update(visible=False),
                "🚀 Initializing Genetic Algorithm...",
                "", {}, "", ""
            )
            time.sleep(0.5)  # Brief pause for UI update
        except Exception as e:
            logger.error(f"Error in initial UI update: {str(e)}")
            raise gr.Error(f"Failed to initialize UI: {str(e)}")
        # Evolution loop (visual progress - actual work happens in safe_optimize)
        try:
            for i in range(1, iter_count + 1):
                try:
                    increment_iteration()
                    add_candidate_to_store({
                        "source": "evolution_step",
                        "prompt": f"Candidate {i}: Optimizing instruction clarity and task alignment...",
                        "timestamp": "now"
                    })
                    progress(i/iter_count, desc=f"Evolution Round {i}/{iter_count}")
                    yield (
                        gr.update(), gr.update(), gr.update(),
                        f"🧬 **Evolution Round {i}/{iter_count}**\n\n• Generating {batch} prompt mutations\n• Evaluating fitness scores\n• Selecting top candidates",
                        "", {}, "", get_candidates_display()
                    )
                    time.sleep(0.3)  # Pause to show progress
                except Exception as e:
                    logger.warning(f"Error in evolution step {i}: {str(e)}")
                    # Continue with next iteration
                    continue
        except Exception as e:
            logger.error(f"Error in evolution loop: {str(e)}")
            # Continue to optimization attempt
        # Final optimization
        try:
            success, msg, result = safe_optimize(
                seed_prompt=seed,
                dataset=dataset,
                model=model,
                custom_model=custom_model,
                max_iterations=iter_count,
                max_metric_calls=call_count,
                batch_size=batch,
                use_llego=llego,
                api_keys=api_keys
            )
            if not success:
                # Show error state
                yield (
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(visible=False),
                    f"❌ **Optimization Failed**\n\n{msg}",
                    "", {}, "", get_candidates_display()
                )
                raise gr.Error(msg)
            # Validate result before displaying
            if not result:
                raise gr.Error("Optimization completed but returned no result.")
            # Check for both property-based (real backend) and attribute-based (mock backend)
            # Try to access the prompt to see if it exists (works for both attributes and properties)
            has_optimized_prompt = False
            try:
                if hasattr(result, 'optimized_prompt'):
                    # Mock backend - direct attribute
                    has_optimized_prompt = True
                elif hasattr(result, 'prompt'):
                    # Real backend - property-based, try to access it
                    _ = result.prompt
                    has_optimized_prompt = True
                elif hasattr(result, '_result') and hasattr(result._result, 'optimized_prompt'):
                    has_optimized_prompt = True
            except Exception:
                pass
            if not has_optimized_prompt:
                raise gr.Error("Optimization result is missing required fields.")
            # Show results
            try:
                # Handle both property-based (real backend) and attribute-based (mock backend)
                if hasattr(result, 'prompt'):
                    # Real backend - use .prompt property
                    try:
                        optimized_prompt = result.prompt or ""
                    except Exception as e:
                        logger.error(f"Error accessing result.prompt: {str(e)}")
                        optimized_prompt = ""
                    # Get improvement_data (real backend)
                    improvement_data = result.improvement_data if hasattr(result, 'improvement_data') else {}
                    # Convert improvement_data to display format
                    # Real backend uses: baseline_val_score, optimized_val_score, relative_improvement_percent
                    if isinstance(improvement_data, dict):
                        # Try real backend field names first, then fall back to alternatives.
                        # Use explicit None checks so a legitimate score of 0.0 is not
                        # mistaken for a missing value (0.0 is falsy, so `or` would skip it).
                        baseline_score = improvement_data.get("baseline_val_score")
                        if baseline_score is None:
                            baseline_score = improvement_data.get("baseline_score")
                        if baseline_score is None:
                            baseline_score = improvement_data.get("baseline_metrics", {}).get("composite_score", 0.0)
                        final_score = improvement_data.get("optimized_val_score")
                        if final_score is None:
                            final_score = improvement_data.get("final_score")
                        if final_score is None:
                            final_score = improvement_data.get("final_metrics", {}).get("composite_score", 0.0)
                        improvement_percent = improvement_data.get("relative_improvement_percent")
                        if improvement_percent is None:
                            improvement_percent = improvement_data.get("improvement_percent", "N/A")
                        # Format improvement percent
                        if isinstance(improvement_percent, (int, float)):
                            improvement_percent = f"+{improvement_percent:.1f}%" if improvement_percent > 0 else f"{improvement_percent:.1f}%"
                        improvement_metrics = {
                            "baseline_score": round(baseline_score, 4) if isinstance(baseline_score, (int, float)) else baseline_score,
                            "final_score": round(final_score, 4) if isinstance(final_score, (int, float)) else final_score,
                            "improvement": improvement_percent,
                            "iterations_run": result.total_iterations if hasattr(result, 'total_iterations') else improvement_data.get("iterations", 0),
                            "optimization_time": f"{result.optimization_time:.2f}s" if hasattr(result, 'optimization_time') else "N/A",
                        }
                        # Log the improvement data for debugging
                        logger.info(f"📊 Improvement data received: {improvement_data}")
                        logger.info(f"📊 Formatted metrics: {improvement_metrics}")
                    else:
                        improvement_metrics = {}
                        logger.warning(f"⚠️ improvement_data is not a dict: {type(improvement_data)}")
                    # Create iteration history from reflection_history if available
                    iteration_history = []
                    if hasattr(result, '_result') and hasattr(result._result, 'reflection_history'):
                        reflection_history = result._result.reflection_history
                        for i, reflection in enumerate(reflection_history, 1):
                            summary = reflection.get('summary', f'Iteration {i}')
                            iteration_history.append(f"Iter {i}: {summary}")
                    elif isinstance(improvement_data, dict) and 'iteration_history' in improvement_data:
                        iteration_history = improvement_data['iteration_history']
                    else:
                        # Fallback: create simple history
                        iterations = result.total_iterations if hasattr(result, 'total_iterations') else 0
                        iteration_history = [f"Iteration {i+1} completed" for i in range(iterations)]
                elif hasattr(result, 'optimized_prompt'):
                    # Mock backend - direct attribute
                    optimized_prompt = result.optimized_prompt or ""
                    improvement_metrics = getattr(result, 'improvement_metrics', {})
                    iteration_history = getattr(result, 'iteration_history', [])
                else:
                    optimized_prompt = ""
                    improvement_metrics = {}
                    iteration_history = []
                history_text = "\n".join(iteration_history) if isinstance(iteration_history, list) else str(iteration_history)
                yield (
                    gr.update(visible=False),
                    gr.update(visible=False),
                    gr.update(visible=True),
                    "✅ Optimization Complete",
                    optimized_prompt,
                    improvement_metrics,
                    history_text,
                    get_candidates_display()
                )
            except Exception as e:
                logger.error(f"Error displaying results: {str(e)}")
                raise gr.Error(f"Failed to display results: {str(e)}")
        except gr.Error:
            # Re-raise Gradio errors
            raise
        except Exception as e:
            logger.error(f"Error in optimization: {str(e)}\n{traceback.format_exc()}")
            raise gr.Error(f"Optimization error: {str(e)}")
    except gr.Error:
        # Re-raise Gradio errors as-is
        raise
    except KeyboardInterrupt:
        logger.warning("Optimization interrupted by user")
        raise gr.Error("Optimization was interrupted.")
    except Exception as e:
        logger.error(f"Unexpected error in optimization flow: {str(e)}\n{traceback.format_exc()}")
        raise gr.Error(f"Unexpected error: {str(e)}")

# ==========================================
# 4. ENHANCED CSS (Linear/Vercel-style)
# ==========================================
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700;800&family=JetBrains+Mono:wght@400;500;600&display=swap');

:root {
    --bg0: #070A0F;
    --bg1: #0B1020;
    --bg2: rgba(255,255,255,0.04);
    --bg3: rgba(255,255,255,0.06);
    --stroke0: rgba(148,163,184,0.14);
    --stroke1: rgba(148,163,184,0.22);
    --text0: #EAF0FF;
    --text1: rgba(234,240,255,0.74);
    --text2: rgba(234,240,255,0.56);
    --teal: #06B6D4;
    --blue: #3B82F6;
    --ok: #10B981;
    --okGlow: rgba(16,185,129,0.18);
    --bad: #EF4444;
    --shadow: 0 12px 40px rgba(0,0,0,0.45);
    --shadowSoft: 0 10px 24px rgba(0,0,0,0.32);
    --radius: 14px;
    --radiusSm: 10px;
}

html, body {
    background: radial-gradient(1200px 700px at 20% -10%, rgba(6,182,212,0.13), transparent 55%),
                radial-gradient(1000px 650px at 90% 0%, rgba(59,130,246,0.10), transparent 60%),
                linear-gradient(180deg, var(--bg0) 0%, var(--bg1) 100%);
    color: var(--text0);
    font-family: Inter, system-ui, -apple-system, "Segoe UI", Roboto, sans-serif;
}

.gradio-container {
    max-width: 1520px !important;
    padding: 12px 18px !important;
    margin: 0 auto !important;
}

/* --- App shell --- */
.app-shell { min-height: auto !important; }

.topbar {
    padding: 12px 14px 12px 14px;
    margin-bottom: 4px;
    border: 1px solid var(--stroke0);
    border-radius: var(--radius);
    background: linear-gradient(180deg, rgba(255,255,255,0.04) 0%, rgba(255,255,255,0.02) 100%);
    box-shadow: var(--shadowSoft);
}
.topbar-wrap { margin-bottom: 0 !important; }
.brand-row { display: flex; align-items: center; justify-content: space-between; gap: 16px; }
.brand-left { display: flex; align-items: center; gap: 14px; }
.brand-mark {
    width: 44px; height: 44px; border-radius: 12px;
    background: linear-gradient(135deg, rgba(6,182,212,0.26), rgba(59,130,246,0.20));
    border: 1px solid rgba(6,182,212,0.30);
    box-shadow: 0 0 0 4px rgba(6,182,212,0.10);
    display: flex; align-items: center; justify-content: center;
    font-weight: 800;
}
.h1 {
    font-size: 22px; font-weight: 800; letter-spacing: -0.02em;
    margin: 0; line-height: 1.2;
}
.subtitle { margin-top: 4px; color: var(--text1); font-weight: 500; font-size: 13px; }
.status-pill {
    display: inline-flex; align-items: center; gap: 10px;
    padding: 10px 12px; border-radius: 999px;
    background: rgba(255,255,255,0.03);
    border: 1px solid var(--stroke0);
    color: var(--text1);
    font-size: 12px; font-weight: 700; letter-spacing: 0.08em;
    text-transform: uppercase;
}
.dot {
    width: 10px; height: 10px; border-radius: 999px;
    background: var(--ok);
    box-shadow: 0 0 16px rgba(16,185,129,0.40);
    animation: pulse 1.8s ease-in-out infinite;
}
@keyframes pulse { 0%, 100% { transform: scale(1); opacity: 0.95; } 50% { transform: scale(1.18); opacity: 0.70; } }

/* --- Two-column layout helpers --- */
.left-col, .right-col { min-width: 280px; }

/* --- Cards / Sections --- */
.card {
    border-radius: var(--radius);
    background: linear-gradient(180deg, rgba(255,255,255,0.045) 0%, rgba(255,255,255,0.022) 100%);
    border: 1px solid var(--stroke0);
    box-shadow: var(--shadowSoft);
    padding: 16px;
}
.card + .card { margin-top: 14px; }
.card-head {
    display: flex; align-items: center; justify-content: space-between;
    gap: 12px;
    padding-bottom: 12px;
    margin-bottom: 12px;
    border-bottom: 1px solid var(--stroke0);
}
.card-title {
    display: flex; align-items: center; gap: 10px;
    font-size: 13px; font-weight: 800; letter-spacing: 0.12em;
    text-transform: uppercase; color: var(--text1);
}
.step {
    width: 30px; height: 30px; border-radius: 10px;
    background: linear-gradient(135deg, rgba(6,182,212,0.95), rgba(59,130,246,0.95));
    box-shadow: 0 10px 20px rgba(6,182,212,0.18);
    display: flex; align-items: center; justify-content: center;
    color: white; font-weight: 900; font-size: 13px;
}
.hint { color: var(--text2); font-size: 12px; line-height: 1.4; }
.ds-count span {
    display: inline-flex;
    align-items: center;
    padding: 7px 10px;
    border-radius: 999px;
    border: 1px solid var(--stroke0);
    background: rgba(255,255,255,0.02);
    color: var(--text1) !important;
    font-weight: 700;
    font-size: 12px;
}

/* --- Inputs --- */
label { color: var(--text1) !important; font-weight: 650 !important; font-size: 12px !important; }
textarea, input, select {
    background: rgba(255,255,255,0.03) !important;
    border: 1px solid var(--stroke0) !important;
    border-radius: 12px !important;
    color: var(--text0) !important;
    transition: border-color 0.15s ease, box-shadow 0.15s ease, transform 0.15s ease;
}
textarea:focus, input:focus, select:focus {
    outline: none !important;
    border-color: rgba(6,182,212,0.55) !important;
    box-shadow: 0 0 0 4px rgba(6,182,212,0.14) !important;
}
.keybox input { font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important; }
.seed textarea { min-height: 160px !important; }
.mono textarea { font-family: "JetBrains Mono", ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace !important; font-size: 12.5px !important; }

/* --- Buttons --- */
.cta button {
    width: 100% !important;
    border: 0 !important;
    border-radius: 14px !important;
    padding: 14px 16px !important;
    font-size: 13px !important;
    font-weight: 900 !important;
    letter-spacing: 0.12em !important;
    text-transform: uppercase !important;
    color: white !important;
    background: linear-gradient(135deg, rgba(6,182,212,1) 0%, rgba(59,130,246,1) 100%) !important;
    box-shadow: 0 18px 48px rgba(6,182,212,0.22) !important;
    position: relative !important;
    overflow: hidden !important;
}
.cta button::after {
    content: "";
    position: absolute; inset: -120px;
    background: radial-gradient(closest-side, rgba(255,255,255,0.18), transparent 60%);
    transform: translateX(-40%);
    transition: transform 0.45s ease;
}
.cta button:hover { transform: translateY(-1px); }
.cta button:hover::after { transform: translateX(40%); }
.cta button:active { transform: translateY(0px); }
.btn-secondary button {
    border-radius: 12px !important;
    border: 1px solid var(--stroke1) !important;
    background: rgba(255,255,255,0.03) !important;
    color: var(--text0) !important;
    font-weight: 800 !important;
}
.btn-secondary button:hover { border-color: rgba(6,182,212,0.55) !important; }
.btn-danger button {
    border-radius: 12px !important;
    border: 1px solid rgba(239,68,68,0.55) !important;
    background: rgba(239,68,68,0.06) !important;
    color: rgba(255,170,170,1) !important;
    font-weight: 900 !important;
}

/* --- Dataframe --- */
.dataframe {
    border-radius: 14px !important;
    border: 1px solid var(--stroke0) !important;
    background: rgba(255,255,255,0.02) !important;
    overflow: hidden !important;
}
.dataframe thead th {
    background: rgba(255,255,255,0.04) !important;
    color: var(--text1) !important;
    font-weight: 900 !important;
    font-size: 11px !important;
    letter-spacing: 0.10em !important;
    text-transform: uppercase !important;
    border-bottom: 1px solid var(--stroke0) !important;
}
.dataframe tbody td {
    color: var(--text0) !important;
    font-size: 12px !important;
    border-bottom: 1px solid rgba(148,163,184,0.10) !important;
}
.dataframe tbody tr:hover { background: rgba(255,255,255,0.03) !important; }

/* --- Status / Results --- */
.panel {
    border-radius: var(--radius);
    border: 1px solid var(--stroke0);
    background: linear-gradient(180deg, rgba(255,255,255,0.045), rgba(255,255,255,0.020));
    box-shadow: var(--shadowSoft);
    padding: 16px;
}
.panel-title {
    display: flex; align-items: center; justify-content: space-between;
    gap: 10px;
    padding-bottom: 12px; margin-bottom: 12px;
    border-bottom: 1px solid var(--stroke0);
}
.panel-title h3 { margin: 0; font-size: 13px; letter-spacing: 0.12em; text-transform: uppercase; color: var(--text1); }
.running-pill {
    display: inline-flex; align-items: center; gap: 10px;
    padding: 8px 10px; border-radius: 999px;
    border: 1px solid rgba(6,182,212,0.38);
    background: rgba(6,182,212,0.08);
    color: rgba(153,246,228,0.95);
    font-weight: 900; font-size: 11px; letter-spacing: 0.10em; text-transform: uppercase;
}
.running-dot { width: 9px; height: 9px; border-radius: 99px; background: var(--teal); box-shadow: 0 0 18px rgba(6,182,212,0.45); animation: pulse 1.8s ease-in-out infinite; }
.empty {
    border-radius: var(--radius);
    border: 1px dashed rgba(148,163,184,0.26);
    background: rgba(255,255,255,0.02);
    padding: 28px;
    text-align: center;
    color: var(--text2);
}
.empty .big { font-size: 40px; opacity: 0.22; margin-bottom: 10px; }
.empty .t { color: var(--text1); font-weight: 800; margin-bottom: 6px; }
.empty .s { font-size: 12px; }
.results {
    border-radius: var(--radius);
    border: 1px solid rgba(16,185,129,0.55);
    background: linear-gradient(180deg, rgba(16,185,129,0.12), rgba(255,255,255,0.02));
    box-shadow: 0 0 0 4px rgba(16,185,129,0.10), 0 20px 60px rgba(0,0,0,0.42);
    padding: 16px;
}
.results-banner {
    display: flex; align-items: center; justify-content: space-between;
    gap: 12px;
    padding-bottom: 12px; margin-bottom: 12px;
    border-bottom: 1px solid rgba(16,185,129,0.28);
}
.results-banner .k { display: flex; align-items: center; gap: 10px; }
.results-banner .k .icon {
    width: 36px; height: 36px; border-radius: 12px;
    background: rgba(16,185,129,0.18);
    border: 1px solid rgba(16,185,129,0.45);
    display: flex; align-items: center; justify-content: center;
}
.results-banner .k .title { font-weight: 900; color: rgba(189,255,225,0.98); letter-spacing: 0.06em; text-transform: uppercase; font-size: 12px; }
.results-banner .k .sub { margin-top: 2px; color: rgba(189,255,225,0.70); font-size: 12px; }
.tabs { background: transparent !important; }
.tab-nav button {
    background: transparent !important;
    border: 0 !important;
    border-bottom: 2px solid transparent !important;
    color: var(--text2) !important;
    font-weight: 800 !important;
    padding: 10px 12px !important;
}
.tab-nav button[aria-selected="true"] {
    color: rgba(153,246,228,0.98) !important;
    border-bottom-color: rgba(6,182,212,0.75) !important;
}
.tab-nav button:hover { color: var(--text0) !important; }
.small-note { color: var(--text2); font-size: 12px; }

/* --- Candidates stream --- */
.cand-empty { padding: 28px; text-align: center; color: var(--text2); }
.cand-empty-icon { font-size: 40px; opacity: 0.25; margin-bottom: 10px; }
.cand-empty-title { color: var(--text1); font-weight: 900; margin-bottom: 4px; }
.cand-empty-sub { font-size: 12px; }
.cand-stream { display: flex; flex-direction: column; gap: 10px; }
.cand-card {
    border-radius: 14px;
    border: 1px solid rgba(148,163,184,0.18);
    background: linear-gradient(135deg, rgba(15,23,42,0.85), rgba(2,6,23,0.45));
    overflow: hidden;
}
.cand-topbar { height: 2px; background: linear-gradient(90deg, var(--teal), var(--blue)); }
.cand-header {
    display: flex; align-items: center; justify-content: space-between;
    gap: 10px;
    padding: 10px 12px 0 12px;
}
.cand-iter { font-family: "JetBrains Mono", ui-monospace; font-size: 11px; color: rgba(153,246,228,0.92); font-weight: 800; letter-spacing: 0.08em; }
.cand-pill {
    font-size: 10px; font-weight: 900; letter-spacing: 0.10em;
    padding: 5px 8px; border-radius: 999px;
    border: 1px solid rgba(148,163,184,0.20);
    background: rgba(255,255,255,0.03);
    color: var(--text2);
}
.cand-body {
    padding: 10px 12px 12px 12px;
    font-family: "JetBrains Mono", ui-monospace;
    font-size: 12px;
    line-height: 1.6;
    color: rgba(234,240,255,0.75);
}

/* --- Responsive --- */
@media (max-width: 980px) {
    .gradio-container { padding: 16px 12px !important; }
    .brand-row { flex-direction: column; align-items: flex-start; }
    .status-pill { align-self: stretch; justify-content: center; }
}
"""

FORCE_DARK_JS = """
function forceDarkTheme() {
    try {
        const url = new URL(window.location.href);
        if (url.searchParams.get("__theme") !== "dark") {
            url.searchParams.set("__theme", "dark");
            window.location.replace(url.toString());
        }
    } catch (e) {
        // no-op
    }
}
forceDarkTheme();
"""

# ==========================================
# 5. UI CONSTRUCTION (Redesigned)
# ==========================================
APP_TITLE = "Universal Prompt Optimizer"
APP_SUBTITLE = "Genetic Evolutionary Prompt Agent (GEPA)"
STATUS_READY = "System Ready"

with gr.Blocks(
    title="Universal Prompt Optimizer",
    theme=gr.themes.Base(),
    css=CUSTOM_CSS  # apply the custom stylesheet defined above
) as app:
    dataset_state = gr.State([])

    # TOP BAR
    gr.HTML(
        f"""
        <div class="topbar">
            <div class="brand-row">
                <div class="brand-left">
                    <div class="brand-mark">GE</div>
                    <div>
                        <div class="h1">{APP_TITLE}</div>
                        <div class="subtitle">{APP_SUBTITLE}</div>
                    </div>
                </div>
                <div class="status-pill"><span class="dot"></span> {STATUS_READY}</div>
            </div>
        </div>
        """,
        elem_classes=["topbar-wrap"]
    )

    # MAIN LAYOUT
    with gr.Row():
        # LEFT COLUMN: Configuration
        with gr.Column(scale=5, elem_classes=["left-col"]):
            # Step 1
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">1</div> Model & Credentials</div>
                        <div class="hint">Select a target model, then provide keys (stored in-session only).</div>
                    </div>
                    """
                )
                with gr.Row():
                    model_select = gr.Dropdown(
                        label="Foundation Model",
                        choices=[
                            "openai/gpt-4o",
                            "openai/gpt-4-turbo",
                            "anthropic/claude-3-5-sonnet",
                            "google/gemini-1.5-pro",
                            "custom"
                        ],
                        value="openai/gpt-4o",
                        scale=2
                    )
                    custom_model_input = gr.Textbox(
                        label="Custom Model ID",
                        placeholder="provider/model_name",
                        scale=1
                    )
                gr.HTML('<div class="subsection-title">API Access Keys</div>')
                gr.Markdown("*Keys are stored in-session only and never logged*", elem_classes=["text-xs"])
                with gr.Row():
                    key_openai = gr.Textbox(
                        label="OpenAI API Key",
                        type="password",
                        placeholder="sk-...",
                        scale=1
                    )
                    key_google = gr.Textbox(
                        label="Google API Key",
                        type="password",
                        placeholder="AIza...",
                        scale=1
                    )
                    key_anthropic = gr.Textbox(
                        label="Anthropic API Key",
                        type="password",
                        placeholder="sk-ant...",
                        scale=1
                    )
            # Step 2
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">2</div> Seed Prompt</div>
                        <div class="hint">Describe the task, constraints, output format, and tone.</div>
                    </div>
                    """
                )
                seed_input = gr.Textbox(
                    label="Task Description",
                    placeholder="Example: You are a code reviewer that identifies security vulnerabilities in Python code. Return a JSON report with severity and fixes...",
                    lines=7,
                    max_lines=14,
                    elem_classes=["seed", "mono"]
                )
            # Step 3
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">3</div> Training Examples</div>
                        <div class="hint">Add a few high-quality I/O pairs (images optional) to shape the optimizer.</div>
                    </div>
                    """
                )
                with gr.Tabs():
                    with gr.Tab("Manual Entry"):
                        with gr.Row():
                            with gr.Column(scale=2):
                                d_in = gr.Textbox(
                                    label="Input / User Prompt",
                                    placeholder="Example user input...",
                                    lines=3
                                )
                                d_out = gr.Textbox(
                                    label="Ideal Output",
                                    placeholder="Expected AI response...",
                                    lines=3
                                )
                            with gr.Column(scale=1):
                                d_img = gr.Image(
                                    label="Attach Image (Optional)",
                                    type="numpy",
                                    height=170
                                )
                        btn_add = gr.Button(
                            "Add Example",
                            elem_classes=["btn-secondary"]
                        )
                    with gr.Tab("Bulk Import (JSON)"):
                        gr.Markdown(
                            "Paste a JSON array like: `[{\"input\": \"...\", \"output\": \"...\"}]`<br>"
                            "**Images**: Upload images below and reference them in JSON using:<br>"
                            "• `\"image_name\": \"filename.png\"` - Match by filename (recommended)<br>"
                            "• `\"image_index\": 0` - Reference by upload order (0-based)<br>"
                            "• `\"image\": \"data:image/...\"` - Include base64 directly",
                            elem_classes=["small-note"]
                        )
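                        # Example of a valid bulk payload (illustrative filenames):
                        #   [
                        #     {"input": "Describe this chart", "output": "A bar chart of ...", "image_name": "chart.png"},
                        #     {"input": "Summarize the text", "output": "..."}
                        #   ]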
                        bulk_json = gr.Textbox(
                            show_label=False,
                            placeholder='[{"input": "...", "output": "...", "image_index": 0}]',
                            lines=6
                        )
                        bulk_images = gr.File(
                            label="Upload Images (Optional) - All formats supported (PNG, JPG, JPEG, GIF, WEBP, BMP, TIFF, etc.)",
                            file_count="multiple",
                            file_types=[".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".tiff", ".tif", ".svg", ".ico", ".heic", ".heif"],
                            height=100
                        )
                        btn_import = gr.Button(
                            "Import JSON",
                            elem_classes=["btn-secondary"]
                        )
                with gr.Row():
                    gr.HTML("<div class='hint'>Current dataset</div>")
                    ds_count = gr.HTML(
                        "<span style='color: var(--text2);'>0 examples loaded</span>",
                        elem_classes=["ds-count"]
                    )
                ds_table = gr.Dataframe(
                    headers=["ID", "Input", "Output", "Media"],
                    datatype=["number", "str", "str", "str"],
                    row_count=6,
                    column_count=(4, "fixed"),
                    interactive=False
                )
                with gr.Row():
                    btn_clear = gr.Button(
                        "Clear All",
                        elem_classes=["btn-danger"],
                        size="sm"
                    )
            # Step 4 (Prominent, not buried)
            with gr.Group(elem_classes=["card"]):
                gr.HTML(
                    """
                    <div class="card-head">
                        <div class="card-title"><div class="step">4</div> Optimization Controls</div>
                        <div class="hint">Tune evolution budget. Defaults are safe for quick runs.</div>
                    </div>
                    """
                )
                with gr.Row():
                    slider_iter = gr.Slider(
                        minimum=1,
                        maximum=20,
                        value=5,
                        step=1,
                        label="Evolution Rounds",
                        info="Number of genetic iterations"
                    )
                    slider_calls = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max LLM Calls",
                        info="Total API call budget"
                    )
                with gr.Row():
                    slider_batch = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        label="Batch Size",
                        info="Candidates per iteration"
                    )
                    check_llego = gr.Checkbox(
                        value=True,
                        label="Enable LLEGO Crossover",
                        info="Use advanced genetic operations"
                    )
                btn_optimize = gr.Button(
                    "Start Optimization",
                    elem_classes=["cta", "mt-6"]
                )
        # RIGHT: STATUS + RESULTS
        with gr.Column(scale=5, elem_classes=["right-col"]):
            # STATUS PANEL (Hidden by default)
            status_panel = gr.Group(visible=False, elem_classes=["panel"])
            with status_panel:
                gr.HTML(
                    """
                    <div class="panel-title">
                        <h3>Optimization status</h3>
                        <div class="running-pill"><span class="running-dot"></span> Running</div>
                    </div>
                    """
                )
                txt_status = gr.Markdown("Initializing genetic algorithm...")
            # EMPTY STATE
            empty_state = gr.HTML(
                """
                <div class="empty">
                    <div class="big">🧬</div>
                    <div class="t">Ready to optimize</div>
                    <div class="s">Fill Steps 1–3, then click <b>Start Optimization</b> to begin prompt evolution.</div>
                </div>
                """,
                visible=True
            )
            # RESULTS PANEL (Hidden by default)
            results_panel = gr.Group(visible=False, elem_classes=["results"])
            with results_panel:
                gr.HTML(
                    """
                    <div class="results-banner">
                        <div class="k">
                            <div class="icon">✓</div>
                            <div>
                                <div class="title">Optimization successful</div>
                                <div class="sub">Review the optimized prompt, metrics, and evolution traces.</div>
                            </div>
                        </div>
                    </div>
                    """
                )
                with gr.Tabs():
                    with gr.Tab("Optimized Prompt"):
                        res_prompt = gr.Textbox(
                            label="Optimized Prompt",
                            lines=18,
                            max_lines=28,
                            interactive=False,
                            show_label=True,
                            elem_classes=["mono"]
                        )
                    with gr.Tab("Metrics & Log"):
                        res_metrics = gr.JSON(label="Performance Gains")
                        res_history = gr.TextArea(
                            label="Evolution Log",
                            interactive=False,
                            lines=10
                        )
                    with gr.Tab("🧬 Live Candidates"):
                        gr.Markdown("Real-time stream of generated prompt candidates during optimization:")
                        live_candidates = gr.HTML()
                        btn_refresh_cand = gr.Button(
                            "🔄 Refresh Stream",
                            elem_classes=["btn-secondary"],
                            size="sm"
                        )

    # ==========================================
    # 6. EVENT HANDLERS
    # ==========================================
    # Dataset Management
    def update_dataset_count(dataset):
        """Update dataset count display with error handling."""
        try:
            if not isinstance(dataset, list):
                return "<span style='color: var(--text2);'>0 examples loaded</span>"
            count = len(dataset)
            return f"<span style='color: var(--text2);'>{count} example{'s' if count != 1 else ''} loaded</span>"
        except Exception as e:
            logger.error(f"Error updating dataset count: {str(e)}")
            return "<span style='color: var(--text2);'>Error</span>"

    # Wrap event handlers with error handling
    def safe_add_example(*args):
        """Wrapper for add_example with error handling."""
        try:
            return add_example(*args)
        except gr.Error:
            raise
        except Exception as e:
            logger.error(f"Unexpected error in add_example: {str(e)}")
            raise gr.Error(f"Failed to add example: {str(e)}")

    def safe_update_table(dataset):
        """Wrapper for update_table with error handling."""
        try:
            return update_table(dataset)
        except Exception as e:
            logger.error(f"Error updating table: {str(e)}")
            return []

    def safe_clear_dataset():
        """Wrapper for clear_dataset with error handling."""
        try:
            return clear_dataset()
        except Exception as e:
            logger.error(f"Error clearing dataset: {str(e)}")
            return [], []

    btn_add.click(
        safe_add_example,
        inputs=[d_in, d_out, d_img, dataset_state],
        outputs=[dataset_state, d_in, d_out, d_img]
    ).then(
        safe_update_table,
        inputs=[dataset_state],
        outputs=[ds_table]
    ).then(
        update_dataset_count,
        inputs=[dataset_state],
        outputs=[ds_count]
    )

    btn_clear.click(
        safe_clear_dataset,
        outputs=[dataset_state, ds_table]
    ).then(
| lambda: "<span style='color: var(--text-secondary);'>0 examples loaded</span>", | |
| outputs=[ds_count] | |
| ) | |
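| # Note: .then() fires even if the previous step raised an error. If the table | |
| # and counter should only refresh after a successful add, .success() is the | |
| # stricter alternative (available in Gradio 4+). | |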
| # Bulk Import | |
| def import_bulk_json(json_text, current_dataset, uploaded_images): | |
| """Import examples from JSON with comprehensive error handling and image support.""" | |
| try: | |
| # Validate inputs | |
| if not json_text or not json_text.strip(): | |
| raise gr.Error("JSON input is empty. Please provide a JSON array.") | |
| if not isinstance(current_dataset, list): | |
| raise gr.Error("Dataset state is invalid. Please refresh the page.") | |
| # Parse JSON | |
| try: | |
| data = json.loads(json_text.strip()) | |
| except json.JSONDecodeError as e: | |
| raise gr.Error(f"Invalid JSON format: {str(e)}. Please check your JSON syntax.") | |
| # Validate structure | |
| if not isinstance(data, list): | |
| raise gr.Error("JSON must be an array of objects. Example: [{\"input\": \"...\", \"output\": \"...\"}]") | |
| if len(data) == 0: | |
| raise gr.Error("JSON array is empty. Add at least one example object.") | |
| # Process uploaded images into base64 format | |
| # Create both a list (for index-based access) and a dict (for filename-based access) | |
| image_list = [] | |
| image_dict = {} # Maps filename -> base64 | |
| original_filenames = [] # Track original filenames for error messages | |
| # Handle case where uploaded_images might be None, empty list, or single file | |
| if uploaded_images: | |
| # Ensure it's a list | |
| if not isinstance(uploaded_images, list): | |
| uploaded_images = [uploaded_images] | |
| logger.info(f"Processing {len(uploaded_images)} uploaded image(s)") | |
| for idx, img_file in enumerate(uploaded_images): | |
| try: | |
| if img_file is None: | |
| logger.warning(f"Image {idx} is None, skipping") | |
| continue | |
| # Extract filename and process image | |
| filename = None | |
| img_b64 = None | |
| file_path = None | |
| # Handle different file input formats (Gradio 6.1.0 returns file paths as strings) | |
| if isinstance(img_file, str): | |
| # File path (most common in Gradio 6.x) | |
| file_path = img_file | |
| if os.path.exists(file_path): | |
| filename = os.path.basename(file_path) | |
| img_b64 = gradio_image_to_base64(file_path) | |
| logger.info(f"Processed image from path: {filename}") | |
| else: | |
| logger.warning(f"File path does not exist: {file_path}") | |
| elif isinstance(img_file, dict): | |
| # Gradio file dict format: {"name": "...", "path": "...", "orig_name": "...", ...} | |
| file_path = img_file.get("path") or img_file.get("name") | |
| # Try to get original filename first, then fall back to path basename | |
| orig_name = img_file.get("orig_name") or img_file.get("name") | |
| if file_path: | |
| if orig_name: | |
| filename = os.path.basename(orig_name) | |
| else: | |
| filename = os.path.basename(file_path) | |
| img_b64 = gradio_image_to_base64(file_path) | |
| logger.info(f"Processed image from dict: {filename} (path: {file_path})") | |
| elif hasattr(img_file, 'name'): | |
| # File object with name attribute | |
| file_path = img_file.name  # the elif above already guarantees this attribute | |
| filename = os.path.basename(file_path) if file_path else None | |
| if file_path and os.path.exists(file_path): | |
| img_b64 = gradio_image_to_base64(file_path) | |
| logger.info(f"Processed image from file object: {filename}") | |
| else: | |
| # Try to process as image directly (numpy array, PIL Image, etc.) | |
| img_b64 = gradio_image_to_base64(img_file) | |
| if img_b64: | |
| filename = f"image_{len(image_list)}.png" | |
| logger.info(f"Processed image as direct input: {filename}") | |
| if img_b64: | |
| image_list.append(img_b64) | |
| # Store by filename (case-insensitive matching) | |
| if filename: | |
| original_filenames.append(filename) | |
| # Store with original filename | |
| image_dict[filename] = img_b64 | |
| # Also store with lowercase for case-insensitive lookup | |
| image_dict[filename.lower()] = img_b64 | |
| # Also store without extension for more flexible matching | |
| base_name = os.path.splitext(filename)[0] | |
| if base_name and base_name != filename: | |
| image_dict[base_name] = img_b64 | |
| image_dict[base_name.lower()] = img_b64 | |
| logger.info(f"Stored image: {filename} (keys: {filename}, {filename.lower()})") | |
| else: | |
| logger.warning(f"Image processed but no filename extracted, using index") | |
| image_dict[f"image_{len(image_list)-1}"] = img_b64 | |
| else: | |
| logger.warning(f"Failed to convert image {idx} to base64 (type: {type(img_file)})") | |
| except Exception as e: | |
| logger.error(f"Failed to process uploaded image {idx}: {str(e)}\n{traceback.format_exc()}") | |
| continue | |
| logger.info(f"Successfully processed {len(image_list)} images. Available filenames: {original_filenames}") | |
| # Validate and import items | |
| imported_count = 0 | |
| errors = [] | |
| for i, item in enumerate(data): | |
| try: | |
| if not isinstance(item, dict): | |
| errors.append(f"Item {i+1}: not a dictionary") | |
| continue | |
| if "input" not in item or "output" not in item: | |
| errors.append(f"Item {i+1}: missing 'input' or 'output' field") | |
| continue | |
| input_val = item["input"] | |
| output_val = item["output"] | |
| if not isinstance(input_val, str) or not isinstance(output_val, str): | |
| errors.append(f"Item {i+1}: 'input' and 'output' must be strings") | |
| continue | |
| if not input_val.strip() or not output_val.strip(): | |
| errors.append(f"Item {i+1}: 'input' and 'output' cannot be empty") | |
| continue | |
| # Handle image - check for image_name first, then image_index, then direct image field | |
| img_b64 = None | |
| if "image_name" in item: | |
| # Match uploaded image by filename | |
| image_name = item["image_name"] | |
| if not isinstance(image_name, str): | |
| errors.append(f"Item {i+1}: 'image_name' must be a string") | |
| continue | |
| if not image_name.strip(): | |
| errors.append(f"Item {i+1}: 'image_name' cannot be empty") | |
| continue | |
| # Try to find matching image (case-insensitive) | |
| image_name_clean = image_name.strip() | |
| logger.info(f"Item {i+1}: Looking for image '{image_name_clean}' in {len(image_dict)} stored images") | |
| # Try exact match first | |
| img_b64 = image_dict.get(image_name_clean) | |
| if not img_b64: | |
| # Try case-insensitive match | |
| img_b64 = image_dict.get(image_name_clean.lower()) | |
| if not img_b64: | |
| # Try matching just the filename without path | |
| basename = os.path.basename(image_name_clean) | |
| img_b64 = image_dict.get(basename) or image_dict.get(basename.lower()) | |
| if img_b64: | |
| logger.info(f"Item {i+1}: Matched image by basename '{basename}'") | |
| if not img_b64: | |
| # Try matching without extension | |
| base_name = os.path.splitext(image_name_clean)[0] | |
| if base_name: | |
| img_b64 = image_dict.get(base_name) or image_dict.get(base_name.lower()) | |
| if img_b64: | |
| logger.info(f"Item {i+1}: Matched image by base name '{base_name}'") | |
| if img_b64: | |
| logger.info(f"Item {i+1}: Successfully matched image '{image_name_clean}'") | |
| else: | |
| # Show available filenames for debugging | |
| available_str = ', '.join(original_filenames[:5]) | |
| if len(original_filenames) > 5: | |
| available_str += f" (and {len(original_filenames) - 5} more)" | |
| if not original_filenames: | |
| available_str = "none uploaded" | |
| # Log warning but continue - don't fail the entire import | |
| logger.warning(f"Item {i+1}: Image '{image_name_clean}' not found. Available images: {available_str}. Image dict keys: {list(image_dict.keys())[:10]}") | |
| elif "image_index" in item: | |
| # Reference uploaded image by index | |
| img_idx = item["image_index"] | |
| if not isinstance(img_idx, int): | |
| errors.append(f"Item {i+1}: 'image_index' must be an integer") | |
| continue | |
| if img_idx < 0 or img_idx >= len(image_list): | |
| errors.append(f"Item {i+1}: 'image_index' {img_idx} is out of range (0-{len(image_list)-1})") | |
| continue | |
| img_b64 = image_list[img_idx] | |
| elif "image" in item: | |
| # Direct base64 image in JSON | |
| img_b64 = item["image"] | |
| if img_b64 and not isinstance(img_b64, str): | |
| errors.append(f"Item {i+1}: 'image' must be a base64 string") | |
| continue | |
| # Add valid item | |
| current_dataset.append({ | |
| "input": input_val.strip(), | |
| "output": output_val.strip(), | |
| "image": img_b64, # Optional - can be None | |
| "image_preview": "🖼️ Image" if img_b64 else "-" | |
| }) | |
| imported_count += 1 | |
| except Exception as e: | |
| errors.append(f"Item {i+1}: {str(e)}") | |
| logger.warning(f"Error importing item {i+1}: {str(e)}") | |
| continue | |
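| # Per-item failures are accumulated in `errors` rather than aborting, so one | |
| # malformed row cannot sink an otherwise valid batch import. | |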
| # Report results | |
| if imported_count == 0: | |
| error_msg = "No valid examples imported. " | |
| if errors: | |
| error_msg += "Errors: " + "; ".join(errors[:3]) | |
| if len(errors) > 3: | |
| error_msg += f" (and {len(errors) - 3} more)" | |
| raise gr.Error(error_msg) | |
| if errors: | |
| warning_msg = f"Imported {imported_count} example(s). " | |
| if len(errors) <= 3: | |
| warning_msg += f"Warnings: {'; '.join(errors)}" | |
| else: | |
| warning_msg += f"{len(errors)} items had errors." | |
| logger.warning(warning_msg) | |
| return current_dataset, "" | |
| except gr.Error: | |
| # Re-raise Gradio errors | |
| raise | |
| except Exception as e: | |
| logger.error(f"Unexpected error in import_bulk_json: {str(e)}\n{traceback.format_exc()}") | |
| raise gr.Error(f"Failed to import JSON: {str(e)}") | |
| btn_import.click( | |
| import_bulk_json, | |
| inputs=[bulk_json, dataset_state, bulk_images], | |
| outputs=[dataset_state, bulk_json] | |
| ).then( | |
| safe_update_table, | |
| inputs=[dataset_state], | |
| outputs=[ds_table] | |
| ).then( | |
| update_dataset_count, | |
| inputs=[dataset_state], | |
| outputs=[ds_count] | |
| ) | |
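| # The import chain mirrors btn_add's: mutate state, then refresh the table, | |
| # then the counter, so all three stay in sync after a bulk load. | |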
| # Main Optimization Flow | |
| btn_optimize.click( | |
| run_optimization_flow, | |
| inputs=[ | |
| seed_input, dataset_state, model_select, custom_model_input, | |
| slider_iter, slider_calls, slider_batch, check_llego, | |
| key_openai, key_google, key_anthropic | |
| ], | |
| outputs=[ | |
| status_panel, empty_state, results_panel, | |
| txt_status, res_prompt, res_metrics, res_history, live_candidates | |
| ] | |
| ) | |
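| # Contract: run_optimization_flow must return (or yield) values in exactly | |
| # this output order: three visibility updates (status panel, empty state, | |
| # results panel) followed by status text, prompt, metrics, history, and the | |
| # candidates HTML. Reordering either list silently misroutes results. | |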
| # Refresh Candidates | |
| def safe_get_candidates_display(): | |
| """Wrapper for get_candidates_display with error handling.""" | |
| try: | |
| return get_candidates_display() | |
| except Exception as e: | |
| logger.error(f"Error refreshing candidates: {str(e)}") | |
| return "<div style='padding: 2rem; text-align: center; color: #ef4444;'>Error loading candidates.</div>" | |
| btn_refresh_cand.click( | |
| safe_get_candidates_display, | |
| outputs=[live_candidates] | |
| ) | |
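| # A push-based alternative (untested sketch, Gradio 5+): create a timer | |
| # inside the Blocks context and poll the store automatically: | |
| #   poll = gr.Timer(2.0) | |
| #   poll.tick(safe_get_candidates_display, outputs=[live_candidates]) | |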
| # ========================================== | |
| # 7. LAUNCH | |
| # ========================================== | |
| if __name__ == "__main__": | |
| # CUSTOM_CSS and FORCE_DARK_JS must be passed to the gr.Blocks(...) | |
| # constructor; Blocks.launch() does not accept css/js keyword arguments. | |
| app.queue().launch( | |
| server_name="0.0.0.0", # HF Spaces routes traffic to 0.0.0.0:7860 | |
| server_port=7860, | |
| share=False, # Set to False for HF Spaces | |
| show_error=True | |
| ) | |