Spaces:
Running
Running
| import os | |
| import time | |
| import gc | |
| import sys | |
| import threading | |
| from itertools import islice | |
| from datetime import datetime | |
| import re | |
| import gradio as gr | |
| import torch | |
| from transformers import pipeline, TextIteratorStreamer | |
| from transformers import AutoTokenizer | |
| from bs4 import BeautifulSoup | |
| import requests | |
| from urllib.parse import quote_plus | |
| import json | |
| import urllib.parse | |
| from config import MODELS | |
# Loaded text-generation pipelines keyed by model name, so each model is
# built at most once per process.
PIPELINES: dict = {}

# Hugging Face access token read from the environment; empty when unset.
access_token: str = os.getenv('HF_TOKEN', '')

# Set by the UI thread to ask the running generation to stop.
cancel_event: threading.Event = threading.Event()
# CSS classes Google has used for result containers, tried in this order.
_GOOGLE_RESULT_CLASSES = ('g', 'tF2Cxc', 'MjjYud', 'yuRUbf')
# CSS classes Google has used for the snippet text inside a result.
_GOOGLE_SNIPPET_CLASSES = ('VwiC3b', 'IsZvec', 'lEBKkf')


def _extract_google_result(result, max_chars):
    """Return ``(title, snippet)`` for one result container, or None to skip it."""
    title_elem = result.find('h3') or result.find('h2')
    title = title_elem.text if title_elem else "No Title"

    snippet_elem = None
    for class_name in _GOOGLE_SNIPPET_CLASSES:
        snippet_elem = result.find('div', class_=class_name)
        if snippet_elem:
            break
    snippet = snippet_elem.text if snippet_elem else ""

    link_elem = result.find('a')
    link = link_elem.get('href') if link_elem else ""
    # Google wraps outbound links as "/url?q=<target>&..."; unwrap them.
    if link and link.startswith('/url?q='):
        link = urllib.parse.unquote(link.split('/url?q=')[1].split('&')[0])
    # Anything still relative is an internal Google link, not a real result.
    if link and not link.startswith('http'):
        return None

    # Collapse runs of whitespace before measuring the snippet length.
    snippet = ' '.join(snippet.split())
    if len(snippet) > max_chars:
        snippet = snippet[:max_chars] + "..."
    if title and snippet:
        return title, snippet
    return None


def _parse_google_results(soup, max_results, max_chars):
    """Collect up to ``max_results`` formatted result lines from a parsed page."""
    containers = []
    for class_name in _GOOGLE_RESULT_CLASSES:
        containers = soup.find_all('div', class_=class_name)
        if containers:
            break
    if not containers:
        # Last resort: any div whose class starts with one of the known names.
        containers = soup.find_all('div', class_=re.compile(r'^(g|tF2Cxc|MjjYud|yuRUbf)'))

    results = []
    for container in containers[:max_results]:
        try:
            parsed = _extract_google_result(container, max_chars)
        except Exception:
            # Best effort: one malformed container must not discard the rest.
            continue
        if parsed:
            title, snippet = parsed
            results.append(f"{len(results)+1}. {title} - {snippet}")
    return results


def google_search_web(query, max_results=6, max_chars=50):
    """Search using Google web scraping with multiple approaches.

    Every combination of the User-Agent strings and search URL variants
    below is tried in order; the first page that yields at least one usable
    result is returned.

    Args:
        query: Search text; URL-encoded with ``quote_plus`` before use.
        max_results: Number of results requested from Google and the upper
            bound on containers inspected per page.
        max_chars: Snippets longer than this are cut and suffixed with "...".

    Returns:
        A list of strings shaped ``"<n>. <title> - <snippet>"``, or an empty
        list when no attempt produced a result. Request and parsing errors
        are logged and never raised.
    """
    user_agents = [
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
        'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
    ]
    # The URLs do not depend on the User-Agent, so build them once.
    encoded_query = quote_plus(query)
    search_urls = [
        f"https://www.google.com/search?q={encoded_query}&safe=off&num={max_results}",
        f"https://www.google.com/search?q={encoded_query}&safe=off&num={max_results}&hl=en",
        f"https://www.google.com/webhp?safe=off&q={encoded_query}&num={max_results}",
    ]

    for user_agent in user_agents:
        headers = {
            'User-Agent': user_agent,
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
            'Accept-Encoding': 'gzip, deflate',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
            'Cache-Control': 'max-age=0',
        }
        for search_url in search_urls:
            try:
                response = requests.get(search_url, headers=headers, timeout=15, verify=True)
                response.raise_for_status()
                soup = BeautifulSoup(response.text, 'html.parser')
                results = _parse_google_results(soup, max_results, max_chars)
            except Exception as exc:
                # Scraping is best effort: report the failure and move on to
                # the next URL / User-Agent combination instead of raising.
                print(f"⚠️ Google search attempt failed: {exc}")
                continue
            if results:
                return results
    return []
def duckduckgo_search(query, max_results=6, max_chars=50):
    """Fallback to DuckDuckGo search.

    Queries the ``ddgs`` client (imported lazily) and formats at most
    ``max_results`` hits as ``"<n>. <title> - <body>"``, cutting bodies
    longer than ``max_chars`` and suffixing them with "...". Any failure,
    including a missing ``ddgs`` package, yields an empty list.
    """
    try:
        from ddgs import DDGS

        formatted = []
        with DDGS() as client:
            hits = client.text(query, region="wt-wt", safesearch="off", timelimit="y")
            for position, hit in enumerate(islice(hits, max_results), start=1):
                heading = hit.get('title', 'No Title')
                snippet = hit.get('body', '')
                if len(snippet) > max_chars:
                    snippet = snippet[:max_chars] + "..."
                formatted.append(f"{position}. {heading} - {snippet}")
        return formatted
    except Exception:
        return []
def bing_search(query, max_results=6, max_chars=50):
    """Fallback to Bing search.

    Fetches Bing's HTML results page and formats each ``li.b_algo`` entry
    as ``"<n>. <title> - <snippet>"``. Entries without snippet text are
    skipped, and snippets longer than ``max_chars`` are cut and suffixed
    with "...". Any request or parsing failure yields an empty list.
    """
    try:
        request_headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
        }
        url = f"https://www.bing.com/search?q={quote_plus(query)}&safeSearch=off&count={max_results}"
        page = requests.get(url, headers=request_headers, timeout=10)
        page.raise_for_status()

        entries = BeautifulSoup(page.text, 'html.parser').find_all('li', class_='b_algo')
        formatted = []
        for entry in entries[:max_results]:
            try:
                heading = entry.find('h2')
                title = "No Title" if not heading else heading.text
                paragraph = entry.find('p')
                snippet = "" if not paragraph else paragraph.text
                if len(snippet) > max_chars:
                    snippet = snippet[:max_chars] + "..."
                if not (title and snippet):
                    continue
                formatted.append(f"{len(formatted)+1}. {title} - {snippet}")
            except Exception:
                continue
        return formatted
    except Exception:
        return []
def retrieve_context(query, max_results=6, max_chars=50):
    """
    Retrieve search snippets, trying each search engine in turn.

    Google is asked first, then DuckDuckGo, then Bing. The first engine
    that returns a non-empty list wins and its results are returned as a
    list of strings; an empty list is returned when all three come back
    empty.
    """
    engines = (
        ("Google", google_search_web),
        ("DuckDuckGo", duckduckgo_search),
        ("Bing", bing_search),
    )
    for label, search in engines:
        hits = search(query, max_results, max_chars)
        if hits:
            print(f"✅ {label} search successful: {len(hits)} results")
            return hits

    print("❌ All search engines failed")
    return []
def load_pipeline(model_name):
    """
    Load and cache a transformers pipeline for text generation.

    Tries bfloat16, then float16, then float32; if every explicit dtype
    fails, makes one last attempt without a dtype argument.

    Args:
        model_name: Key into ``MODELS``; its ``"repo_id"`` entry names the
            repository to load.

    Returns:
        The pipeline for ``model_name``, reused from ``PIPELINES`` when it
        has been loaded before.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODELS``.
        Exception: Whatever the final fallback ``pipeline(...)`` call raises
            when no configuration can be loaded.
    """
    if model_name in PIPELINES:
        return PIPELINES[model_name]

    repo = MODELS[model_name]["repo_id"]
    tokenizer = AutoTokenizer.from_pretrained(repo, token=access_token)

    # Arguments shared by every attempt. The token is included here so the
    # final fallback can also reach gated or private repositories.
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # model repository; confirm every entry in MODELS is trusted.
    common_kwargs = dict(
        task="text-generation",
        model=repo,
        tokenizer=tokenizer,
        trust_remote_code=True,
        device_map="auto",
        use_cache=True,
        token=access_token,
    )

    for dtype in (torch.bfloat16, torch.float16, torch.float32):
        try:
            pipe = pipeline(dtype=dtype, **common_kwargs)
        except Exception as exc:
            # Report why this dtype was rejected, then try the next one.
            print(f"⚠️ Loading {repo} with {dtype} failed: {exc}")
            continue
        PIPELINES[model_name] = pipe
        return pipe

    # Final fallback: let the library choose the dtype. Errors propagate.
    pipe = pipeline(**common_kwargs)
    PIPELINES[model_name] = pipe
    return pipe
def format_conversation(history, system_prompt, tokenizer):
    """Render the system prompt and chat history as a single prompt string.

    When the tokenizer has a non-empty ``chat_template``, the messages are
    rendered through ``apply_chat_template`` with a generation prompt
    appended. Otherwise a plain "User: ... / Assistant: ..." transcript is
    built, ending with an open "Assistant: " turn.
    """
    system_text = system_prompt.strip()

    if getattr(tokenizer, "chat_template", None):
        conversation = [{"role": "system", "content": system_text}] + history
        return tokenizer.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )

    # Fallback for base LMs without a chat template.
    speaker_prefix = {'user': "User: ", 'assistant': "Assistant: "}
    pieces = [system_text + "\n"]
    for message in history:
        prefix = speaker_prefix.get(message['role'])
        if prefix is not None:
            pieces.append(prefix + message['content'].strip() + "\n")
    transcript = "".join(pieces)

    # Leave an open assistant turn unless the transcript already ends in one.
    if not transcript.strip().endswith("Assistant:"):
        transcript += "Assistant: "
    return transcript
def get_duration(user_msg, chat_history, system_prompt, enable_search, max_results, max_chars, model_name, max_tokens, temperature, top_k, top_p, repeat_penalty, search_timeout):
    """Estimate the GPU time, in seconds, needed for one chat request.

    The signature mirrors ``chat_response``; only ``model_name``,
    ``max_tokens`` and ``enable_search`` influence the result.

    Returns:
        Base time (doubled for models of 2B parameters or more), plus
        0.005 s per requested token, plus 10 s when web search is enabled,
        plus a 20 s compilation buffer for models of 2B parameters or more.
    """
    # Unknown models fall back to 4.0B, matching get_model_size(), so an
    # unexpected model name does not raise KeyError here.
    model_size = MODELS.get(model_name, {}).get("params_b", 4.0)

    # Only use AOT for models >= 2B parameters
    use_aot = model_size >= 2

    # Adjusted for H200 performance
    base_duration = 40 if use_aot else 20
    token_duration = max_tokens * 0.005
    search_duration = 10 if enable_search else 0
    aot_compilation_buffer = 20 if use_aot else 0

    return base_duration + token_duration + search_duration + aot_compilation_buffer
def get_model_size(model_name):
    """Return the ``params_b`` value recorded for a model in ``MODELS``.

    Falls back to 4.0 when the model is unknown or has no ``params_b`` entry.
    """
    model_entry = MODELS.get(model_name, {})
    return model_entry.get("params_b", 4.0)
# Fixed instruction lines appended to the system prompt when search results
# are available.
_SEARCH_GROUNDING_RULES = (
    'RULES (VERY IMPORTANT):',
    '- Do NOT use outside knowledge. Do NOT guess or fill missing information.',
    '- If the answer is not clearly supported by the search results, say: "Not enough information in the provided sources."',
    '- Every factual statement must be directly supported by at least one citation [citation:X].',
    '- Do NOT add explanations, examples, or background that are not explicitly present in the sources.',
    '- Do NOT paraphrase beyond what is necessary for clarity.',
    '- If sources conflict, mention the conflict and cite both.',
    '- If multiple sources are used, distribute citations per sentence, not only at the end.',
    'CITATION RULES:',
    '- Use inline citations like this: [citation:1]',
    '- If multiple sources support a sentence: [citation:1][citation:3]',
    '- Never place all citations only at the end.',
    'ANSWER POLICY:',
    '- Be concise and strictly grounded.',
    '- No speculation, no assumptions, no "likely", no "probably".',
    '- If the user requests a list, only include items explicitly found in sources.',
    '- If sources are insufficient, stop and ask for more data instead of guessing.',
)


def _build_search_prompt(system_prompt, search_results, cur_date):
    """Append the search snippets, grounding rules and date to the system prompt."""
    sections = [
        system_prompt.strip(),
        '# SEARCH CONTEXT (TRUSTED SOURCES ONLY)',
        'Below are search results. Treat them as the ONLY source of truth for answering.',
        # One numbered snippet per line, so the "[citation:X]" numbers in the
        # rules refer to readable entries rather than to a Python list repr.
        *search_results,
        *_SEARCH_GROUNDING_RULES,
        'DATE CONTEXT:',
        f'- Today is {cur_date} (use only for time reference, not for assumptions).',
        'USER QUESTION:',
        '',
    ]
    return '\n'.join(sections)


def chat_response(user_msg, chat_history, system_prompt,
                  enable_search, max_results, max_chars,
                  model_name, max_tokens, temperature,
                  top_k, top_p, repeat_penalty, search_timeout):
    """
    Generates streaming chat responses, optionally with background web search.

    The model runs in a worker thread and its output is consumed through a
    ``TextIteratorStreamer``. Text between ``<think>`` and ``</think>`` is
    routed to a separate "💭 Thought" assistant message; everything else
    goes to the answer message.

    Args:
        user_msg: The new user message.
        chat_history: Prior messages as ``{'role', 'content'}`` dicts, or None.
        system_prompt: Base system prompt; enriched with search snippets and
            grounding rules when a search returned results.
        enable_search: Whether to run ``retrieve_context`` before generating.
        max_results, max_chars: Forwarded to ``retrieve_context`` as ints.
        model_name: Key passed to ``load_pipeline``.
        max_tokens, temperature, top_k, top_p, repeat_penalty: Generation
            settings forwarded to the pipeline call.
        search_timeout: Seconds to wait for the search thread.

    Yields:
        ``(history, debug)`` tuples: the updated message list and a status
        string. The final yield appends a preview of the rendered prompt.

    Cancellation: setting the module-level ``cancel_event`` or closing this
    generator stops both the streaming loop and the underlying generation.
    """
    # Imported here so this function does not rely on the file-level import
    # list being extended.
    from transformers import StoppingCriteria, StoppingCriteriaList

    # Clear the cancellation event at the start of a new generation
    cancel_event.clear()

    # Per-call stop flag. The global cancel_event may be cleared again by the
    # UI right after it is set, so the generation thread watches this private
    # flag instead; it is set on cancel and whenever this generator exits.
    stop_event = threading.Event()

    class _StopOnEvent(StoppingCriteria):
        """Stopping criterion that ends generation once the stop flag is set."""

        def __call__(self, input_ids, scores, **kwargs):
            return stop_event.is_set()

    history = list(chat_history or [])
    history.append({'role': 'user', 'content': user_msg})

    search_results = []
    if enable_search:
        collected = []
        thread_search = threading.Thread(
            target=lambda: collected.extend(
                retrieve_context(user_msg, int(max_results), int(max_chars))
            ),
            daemon=True,
        )
        thread_search.start()
        thread_search.join(timeout=float(search_timeout))
        # Snapshot what arrived in time. A search that finishes after the
        # timeout must not change the prompt behind the debug text's back.
        search_results = list(collected)
        if search_results:
            debug = f"✅ Search completed - Found {len(search_results)} results\n\n" + "\n".join(
                f"- {r}" for r in search_results
            )
        else:
            debug = "❌ No search results found. Check internet connection or try again."
    else:
        debug = 'Web search disabled.'

    try:
        cur_date = datetime.now().strftime('%Y-%m-%d')

        # Prepare enriched system prompt
        if search_results:
            enriched = _build_search_prompt(system_prompt, search_results, cur_date)
        else:
            enriched = system_prompt.strip()

        pipe = load_pipeline(model_name)
        prompt = format_conversation(history, enriched, pipe.tokenizer)
        prompt_debug = f"\n\n--- Prompt Preview ---\n```\n{prompt}\n```"

        streamer = TextIteratorStreamer(pipe.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        generation_kwargs = {
            'max_new_tokens': max_tokens,
            'temperature': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'repetition_penalty': repeat_penalty,
            'streamer': streamer,
            'return_full_text': False,
            'stopping_criteria': StoppingCriteriaList([_StopOnEvent()]),
        }
        generation_errors = []

        def _run_generation():
            """Run the pipeline; on failure, record the error and end the stream."""
            try:
                pipe(prompt, **generation_kwargs)
            except Exception as exc:
                generation_errors.append(exc)
                # Without this the consumer loop below would wait forever
                # for text that is never produced.
                streamer.end()

        gen_thread = threading.Thread(target=_run_generation)
        gen_thread.start()

        # Buffers for thought vs answer
        thought_buf = ''
        answer_buf = ''
        in_thought = False
        assistant_message_started = False

        # First yield contains the user message
        yield history, debug

        # Stream tokens
        for chunk in streamer:
            # Check for cancellation signal
            if cancel_event.is_set():
                stop_event.set()
                if assistant_message_started and history and history[-1]['role'] == 'assistant':
                    history[-1]['content'] += " [Generation Canceled]"
                yield history, debug
                break

            text = chunk

            # Detect start of thinking
            if not in_thought and '<think>' in text:
                in_thought = True
                # Start from an empty buffer so an earlier thought's closing
                # tag cannot be matched again.
                thought_buf = ''
                history.append({'role': 'assistant', 'content': '', 'metadata': {'title': '💭 Thought'}})
                assistant_message_started = True
                text = text.split('<think>', 1)[1]

            if in_thought:
                thought_buf += text
                if '</think>' in thought_buf:
                    before, after2 = thought_buf.split('</think>', 1)
                    history[-1]['content'] = before.strip()
                    in_thought = False
                    # Text after the closing tag opens the answer message.
                    answer_buf = after2
                    history.append({'role': 'assistant', 'content': answer_buf})
                else:
                    history[-1]['content'] = thought_buf
                yield history, debug
                continue

            # Stream answer
            if not assistant_message_started:
                history.append({'role': 'assistant', 'content': ''})
                assistant_message_started = True
            answer_buf += text
            history[-1]['content'] = answer_buf.strip()
            yield history, debug

        gen_thread.join()
        if generation_errors:
            # Surface the worker thread's failure through the handler below.
            raise generation_errors[0]
        yield history, debug + prompt_debug
    except GeneratorExit:
        # Handle cancellation gracefully
        print("Chat response cancelled.")
        return
    except Exception as e:
        history.append({'role': 'assistant', 'content': f"Error: {e}"})
        yield history, debug
    finally:
        # Make sure the worker stops generating no matter how we leave.
        stop_event.set()
        gc.collect()
def update_default_prompt(enable_search):
    """Return the default system prompt.

    ``enable_search`` is accepted so the function can serve as a callback
    for the search checkbox, but the prompt is the same either way.
    """
    # Plain literal: the original f-string had no placeholders.
    return "You are a helpful assistant."
def update_duration_estimate(model_name, enable_search, max_results, max_chars, max_tokens, search_timeout):
    """Build the Markdown summary of the estimated GPU time for the current settings.

    Delegates to ``get_duration`` with placeholder message, history and
    sampling values, and to ``get_model_size`` for the parameter count.
    Any exception is turned into a warning string instead of propagating.
    """
    try:
        seconds = get_duration(
            "", [], "",
            enable_search, max_results, max_chars, model_name,
            max_tokens, 0.7, 40, 0.9, 1.2, search_timeout,
        )
        size_b = get_model_size(model_name)
        summary_lines = (
            f"⏱️ **Estimated GPU Time: {seconds:.1f} seconds**\n\n",
            f"📊 **Model Size:** {size_b:.1f}B parameters\n",
            f"🔍 **Web Search:** {'Enabled (Multi-Engine)' if enable_search else 'Disabled'}",
        )
        return "".join(summary_lines)
    except Exception as e:
        return f"⚠️ Error calculating estimate: {e}"
# ------------------------------
# Gradio UI
# ------------------------------
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a flattened copy of this file; confirm it matches the intended layout.
with gr.Blocks(
    title="LLM Inference",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="slate",
        radius_size="lg",
        font=[gr.themes.GoogleFont("Syne"), "Arial", "sans-serif"]
    ),
    css="""
    .duration-estimate { background: linear-gradient(135deg, #667eea15 0%, #764ba215 100%); border-left: 4px solid #667eea; padding: 12px; border-radius: 8px; margin: 16px 0; }
    .chatbot { border-radius: 12px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1); }
    button.primary { font-weight: 600; }
    .gradio-accordion { margin-bottom: 12px; }
    """
) as demo:
    # Header
    gr.Markdown("""
    # 🧠 LLM Inference with Multi-Engine Search
    """)

    with gr.Row():
        # Left Panel - Configuration
        with gr.Column(scale=3):
            # Core Settings (Always Visible)
            with gr.Group():
                gr.Markdown("### ⚙️ Core Settings")
                model_dd = gr.Dropdown(
                    label="🤖 Model",
                    choices=list(MODELS.keys()),
                    value="Qwen3-1.7B",
                    info="Select the language model to use"
                )
                search_chk = gr.Checkbox(
                    label="🔍 Enable Web Search",
                    value=False,
                    info="Search across Google, DuckDuckGo, and Bing (no API required)"
                )
                sys_prompt = gr.Textbox(label="📝 System Prompt", lines=3, value=update_default_prompt(False), placeholder="Define the assistant's behavior and personality...")

            # Duration Estimate; the arguments mirror the widget defaults below.
            duration_display = gr.Markdown(
                value=update_duration_estimate("Qwen3-1.7B", False, 4, 50, 1024, 5.0),
                elem_classes="duration-estimate"
            )

            # Advanced Settings (Collapsible)
            with gr.Accordion("🎛️ Advanced Generation Parameters", open=False):
                max_tok = gr.Slider(
                    64, 16384, value=1024, step=32,
                    label="Max Tokens",
                    info="Maximum length of generated response"
                )
                temp = gr.Slider(
                    0.1, 2.0, value=0.7, step=0.1,
                    label="Temperature",
                    info="Higher = more creative, Lower = more focused"
                )
                with gr.Row():
                    k = gr.Slider(
                        1, 100, value=40, step=1,
                        label="Top-K",
                        info="Number of top tokens to consider"
                    )
                    p = gr.Slider(
                        0.1, 1.0, value=0.9, step=0.05,
                        label="Top-P",
                        info="Nucleus sampling threshold"
                    )
                rp = gr.Slider(
                    1.0, 2.0, value=1.2, step=0.1,
                    label="Repetition Penalty",
                    info="Penalize repeated tokens"
                )

            # Web Search Settings (Collapsible); shown only while search is enabled.
            with gr.Accordion("🌐 Web Search Settings", open=False, visible=False) as search_settings:
                mr = gr.Number(
                    value=4, precision=0,
                    label="Max Results",
                    info="Number of search results to retrieve"
                )
                mc = gr.Number(
                    value=50, precision=0,
                    label="Max Chars/Result",
                    info="Character limit per search result"
                )
                st = gr.Slider(
                    minimum=0.0, maximum=30.0, step=0.5, value=5.0,
                    label="Search Timeout (s)",
                    info="Maximum time to wait for search results"
                )
                gr.Markdown("""
                ⚠️ **Search Engines:**
                - Google (primary)
                - DuckDuckGo (fallback)
                - Bing (fallback)
                SafeSearch is **OFF** for comprehensive results.
                """)

            # Actions
            with gr.Row():
                clr = gr.Button("🗑️ Clear Chat", variant="secondary", scale=1)

        # Right Panel - Chat Interface
        with gr.Column(scale=7):
            # NOTE(review): sanitize_html=False renders HTML in model output
            # as-is; confirm this is intended for untrusted generations.
            chat = gr.Chatbot(
                type="messages",
                height=600,
                label="💬 Conversation",
                show_copy_button=True,
                avatar_images=(
                    "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40'%3E%3Crect width='40' height='40' rx='20' fill='%23f093fb'/%3E%3Ctext x='20' y='28' text-anchor='middle' font-size='20' fill='white' font-family='Arial'%3E👤%3C/text%3E%3C/svg%3E",
                    "data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='40' height='40'%3E%3Crect width='40' height='40' rx='20' fill='%23667eea'/%3E%3Ctext x='20' y='28' text-anchor='middle' font-size='20' fill='white' font-family='Arial'%3E🤖%3C/text%3E%3C/svg%3E"
                ),
                bubble_full_width=False,
                render_markdown=True,
                sanitize_html=False
            )

            # Input Area
            with gr.Row():
                txt = gr.Textbox(
                    placeholder="💭 Type your message here... (Press Enter to send)",
                    scale=9,
                    container=False,
                    show_label=False,
                    lines=1,
                    max_lines=5
                )
                with gr.Column(scale=1, min_width=120):
                    submit_btn = gr.Button("📤 Send", variant="primary", size="lg")
                    cancel_btn = gr.Button("⏹️ Stop", variant="stop", visible=False, size="lg")

            # Example Prompts
            gr.Examples(
                examples=[
                    ["Explain quantum computing in simple terms"],
                    ["Write a Python function to calculate fibonacci numbers"],
                    ["What are the latest developments in AI? (Enable web search)"],
                    ["Tell me a creative story about a time traveler"],
                    ["Help me debug this code: def add(a,b): return a+b+1"]
                ],
                inputs=txt,
                label="💡 Example Prompts"
            )

            # Debug/Status Info (Collapsible)
            with gr.Accordion("🔍 Debug Info", open=False):
                dbg = gr.Markdown()

    # Footer
    gr.Markdown("""
    ---
    💡 **Tips:**
    - Use **Advanced Parameters** to fine-tune creativity and response length
    - Enable **Web Search** for real-time information (uses multiple search engines)
    - SafeSearch is **OFF** for comprehensive results
    - Try different **models** for various tasks (reasoning, coding, general chat)
    - Click the **Copy** button on responses to save them to your clipboard
    """, elem_classes="footer")

    # --- Event Listeners ---
    # Group all inputs for cleaner event handling; the order matches the
    # parameter order of chat_response().
    chat_inputs = [txt, chat, sys_prompt, search_chk, mr, mc, model_dd, max_tok, temp, k, p, rp, st]
    # Group all UI components that can be updated.
    ui_components = [chat, dbg, txt, submit_btn, cancel_btn]

    def submit_and_manage_ui(user_msg, chat_history, *args):
        """
        Orchestrator function that manages UI state and calls the backend chat function.

        Yields dicts mapping components to their updates: first the
        "generating" state, then one chat/debug update per backend chunk,
        and finally the idle state unless the run was cancelled.
        """
        if not user_msg.strip():
            # Nothing to send: emit an empty update and leave the UI untouched.
            yield {}
            return

        # Update UI to "generating" state
        yield {
            txt: gr.update(value="", interactive=False),
            submit_btn: gr.update(interactive=False),
            cancel_btn: gr.update(visible=True),
        }

        cancelled = False
        try:
            backend_args = [user_msg, chat_history] + list(args)
            for response_chunk in chat_response(*backend_args):
                yield {
                    chat: response_chunk[0],
                    dbg: response_chunk[1],
                }
        except GeneratorExit:
            # This generator is being closed; yielding again is not allowed,
            # so remember to skip the reset below.
            cancelled = True
            print("Generation cancelled by user.")
            raise
        except Exception as e:
            print(f"An error occurred during generation: {e}")
            error_history = (chat_history or []) + [
                {'role': 'user', 'content': user_msg},
                {'role': 'assistant', 'content': f"**An error occurred:** {str(e)}"}
            ]
            yield {chat: error_history}
        finally:
            if not cancelled:
                print("Resetting UI state.")
                yield {
                    txt: gr.update(interactive=True),
                    submit_btn: gr.update(interactive=True),
                    cancel_btn: gr.update(visible=False),
                }

    def set_cancel_flag():
        """Called by the cancel button, sets the global event."""
        cancel_event.set()
        print("Cancellation signal sent.")

    def reset_ui_after_cancel():
        """Reset UI components after cancellation."""
        cancel_event.clear()
        print("UI reset after cancellation.")
        return {
            txt: gr.update(interactive=True),
            submit_btn: gr.update(interactive=True),
            cancel_btn: gr.update(visible=False),
        }

    # Event for submitting text via Enter key or Submit button
    submit_event = txt.submit(
        fn=submit_and_manage_ui,
        inputs=chat_inputs,
        outputs=ui_components,
    )
    click_event = submit_btn.click(
        fn=submit_and_manage_ui,
        inputs=chat_inputs,
        outputs=ui_components,
    )

    # Event for the "Cancel" button. Both ways of starting a run are listed
    # so that a generation started with the Send button is cancelled too.
    cancel_btn.click(
        fn=set_cancel_flag,
        cancels=[submit_event, click_event]
    ).then(
        fn=reset_ui_after_cancel,
        outputs=ui_components
    )

    # Listeners for updating the duration estimate
    duration_inputs = [model_dd, search_chk, mr, mc, max_tok, st]
    for component in duration_inputs:
        component.change(fn=update_duration_estimate, inputs=duration_inputs, outputs=duration_display)

    # Toggle web search settings visibility
    def toggle_search_settings(enabled):
        """Return an update that shows the search settings only when enabled."""
        return gr.update(visible=enabled)

    search_chk.change(
        fn=lambda enabled: (update_default_prompt(enabled), toggle_search_settings(enabled)),
        inputs=search_chk,
        outputs=[sys_prompt, search_settings]
    )

    # Clear chat action
    clr.click(fn=lambda: ([], "", ""), outputs=[chat, txt, dbg])

demo.launch()