diff --git "a/psyllm.py" "b/psyllm.py" --- "a/psyllm.py" +++ "b/psyllm.py" @@ -1,5 +1,21 @@ import os os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# API Key Configuration - Set your API keys here or as environment variables +# You can also set these as environment variables: MISTRAL_API_KEY, OPENAI_API_KEY, etc. +MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY", "") # Set your Mistral API key here +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "") # Set your OpenAI API key here +NEBIUS_API_KEY = os.environ.get("NEBIUS_API_KEY", "") # Set your Nebius API key here +GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "") # Set your Gemini API key here +ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "") # Set your Anthropic API key here +GROK_API_KEY = os.environ.get("GROK_API_KEY", "") # Set your Grok API key here +HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "") # Set your HuggingFace API token here + +# If you want to set API keys directly in the code, uncomment and modify the lines below: +# MISTRAL_API_KEY = "your_mistral_api_key_here" +# OPENAI_API_KEY = "your_openai_api_key_here" +# NEBIUS_API_KEY = "your_nebius_api_key_here" + import datetime import functools import traceback @@ -25,25 +41,48 @@ import gradio as gr import requests from pydantic import PrivateAttr import pydantic +import zipfile +import mimetypes from langchain.llms.base import LLM from typing import Any, Optional, List import typing import time +import sys +import csv +import statistics +import re -print("Pydantic Version: ") -print(pydantic.__version__) -# Add Mistral imports with fallback handling - +# Add OpenAI import for NEBIUS with version check +try: + import openai + from importlib.metadata import version as pkg_version + openai_version = pkg_version("openai") + print(f"OpenAI import success, version: {openai_version}") + if tuple(map(int, openai_version.split("."))) < (1, 0, 0): + print("ERROR: openai version must be >= 1.0.0 for NEBIUS support. Please upgrade with: pip install --upgrade openai") + sys.exit(1) + from openai import OpenAI + OPENAI_AVAILABLE = True +except ImportError as e: + OPENAI_AVAILABLE = False + print("OpenAI import failed:", e) +except Exception as e: + print("OpenAI version check failed:", e) + sys.exit(1) + +# Add Mistral import with better error handling try: from mistralai import Mistral MISTRAL_AVAILABLE = True - debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}") - debug_print("Loaded latest Mistral client library") -except ImportError: + print("Mistral import success") +except ImportError as e: + MISTRAL_AVAILABLE = False + print("Mistral import failed:", e) + print("Please install mistralai package with: pip install mistralai") +except Exception as e: MISTRAL_AVAILABLE = False - debug_print = lambda msg: print(f"[{datetime.datetime.now().isoformat()}] {msg}") - debug_print("Mistral client library not found. 
Install with: pip install mistralai") + print("Mistral import error:", e) def debug_print(message: str): print(f"[{datetime.datetime.now().isoformat()}] {message}", flush=True) @@ -134,70 +173,59 @@ def load_pdfs_async(file_links, model_choice, prompt_template, bm25_weight, temp get_job_list() # Return updated job list ) -def submit_query_async(query, use_llama, use_mistral, temperature, top_p): +def submit_query_async(query, model1, model2, temperature, top_p, top_k, max_tokens): """Asynchronous version of submit_query_updated to prevent timeouts""" global last_job_id if not query: return ("Please enter a non-empty query", "Input/Output tokens: 0/0", "Please enter a non-empty query", "Input/Output tokens: 0/0", "", "", get_job_list()) - - if not (use_llama or use_mistral): + if not (model1 or model2): return ("Please select at least one model", "Input/Output tokens: 0/0", "Please select at least one model", "Input/Output tokens: 0/0", "", "", get_job_list()) - - responses = {"llama": None, "mistral": None} + responses = {"model1": None, "model2": None} job_ids = [] - - if use_llama: - llama_job_id = str(uuid.uuid4()) - debug_print(f"Starting async job {llama_job_id} for Llama query: {query}") - - # Start background thread for Llama + if model1: + model1_job_id = str(uuid.uuid4()) + debug_print(f"Starting async job {model1_job_id} for Model 1: {model1}") threading.Thread( target=process_in_background, - args=(llama_job_id, submit_query_updated, [query, "πΊπΈ Remote Meta-Llama-3", temperature, top_p]) + args=(model1_job_id, submit_query_updated, [query, model1, temperature, top_p, top_k, max_tokens]) ).start() - - jobs[llama_job_id] = { + jobs[model1_job_id] = { "status": "processing", "type": "query", "start_time": time.time(), "query": query, - "model": "Llama" + "model": model1, + "model_position": "model1" } - job_ids.append(llama_job_id) - responses["llama"] = f"Processing (Job ID: {llama_job_id})" - - if use_mistral: - mistral_job_id = str(uuid.uuid4()) - debug_print(f"Starting async job {mistral_job_id} for Mistral query: {query}") - - # Start background thread for Mistral + job_ids.append(model1_job_id) + responses["model1"] = f"Processing (Job ID: {model1_job_id})" + if model2: + model2_job_id = str(uuid.uuid4()) + debug_print(f"Starting async job {model2_job_id} for Model 2: {model2}") threading.Thread( target=process_in_background, - args=(mistral_job_id, submit_query_updated, [query, "πͺπΊ Mistral-API", temperature, top_p]) + args=(model2_job_id, submit_query_updated, [query, model2, temperature, top_p, top_k, max_tokens]) ).start() - - jobs[mistral_job_id] = { + jobs[model2_job_id] = { "status": "processing", "type": "query", "start_time": time.time(), "query": query, - "model": "Mistral" + "model": model2, + "model_position": "model2" } - job_ids.append(mistral_job_id) - responses["mistral"] = f"Processing (Job ID: {mistral_job_id})" - - # Store the last job ID (use the first one for now) + job_ids.append(model2_job_id) + responses["model2"] = f"Processing (Job ID: {model2_job_id})" last_job_id = job_ids[0] if job_ids else None - return ( - responses.get("llama", "Not selected"), - "Input tokens: " + str(count_tokens(query)) if use_llama else "Not selected", - responses.get("mistral", "Not selected"), - "Input tokens: " + str(count_tokens(query)) if use_mistral else "Not selected", + responses.get("model1", "Not selected"), + "Input tokens: " + str(count_tokens(query)) if model1 else "Not selected", + responses.get("model2", "Not selected"), + "Input tokens: " + 
str(count_tokens(query)) if model2 else "Not selected", last_job_id, query, get_job_list() @@ -270,7 +298,8 @@ def sync_model_dropdown(value): # Function to check job status def check_job_status(job_id): if not job_id: - return "Please enter a job ID", "", "", "", "" + # Always return 9 outputs (pad with empty strings) + return "Please enter a job ID", "", "", "", "", "", "", "", "" # Process any completed jobs in the queue try: @@ -286,11 +315,17 @@ def check_job_status(job_id): # Check if the requested job exists if job_id not in jobs: - return "Job not found. Please check the ID and try again.", "", "", "", "" + return "Job not found. Please check the ID and try again.", "", "", "", "", "", "", "", "" job = jobs[job_id] job_query = job.get("query", "No query available for this job") + # Get model response updates + model1_resp, model1_tok, model2_resp, model2_tok = update_model_responses_from_jobs() + + # Generate detailed status report + status_report = generate_detailed_job_status(job_id, job) + # If job is still processing if job["status"] == "processing": elapsed_time = time.time() - job["start_time"] @@ -298,21 +333,27 @@ def check_job_status(job_id): if job_type == "load_files": return ( - f"Files are still being processed (elapsed: {elapsed_time:.1f}s).\n\n" - f"Try checking again in a few seconds.", + status_report, f"Job ID: {job_id}", f"Status: Processing", "", - job_query + job_query, + model1_resp, + model1_tok, + model2_resp, + model2_tok ) else: # query job return ( - f"Query is still being processed (elapsed: {elapsed_time:.1f}s).\n\n" - f"Try checking again in a few seconds.", + status_report, f"Job ID: {job_id}", f"Input tokens: {count_tokens(job.get('query', ''))}", "Output tokens: pending", - job_query + job_query, + model1_resp, + model1_tok, + model2_resp, + model2_tok ) # If job is completed @@ -322,23 +363,160 @@ def check_job_status(job_id): if job.get("type") == "load_files": return ( - f"{result[0]}\n\nProcessing time: {processing_time:.1f}s", + status_report, result[1], result[2], "", - job_query + job_query, + model1_resp, + model1_tok, + model2_resp, + model2_tok ) else: # query job + # Defensive: pad result to at least 4 elements + r = list(result) if isinstance(result, (list, tuple)) else [result] + while len(r) < 4: + r.append("") return ( - f"{result[0]}\n\nProcessing time: {processing_time:.1f}s", - result[1], - result[2], - result[3], - job_query + status_report, + r[1], + r[2], + r[3], + job_query, + model1_resp, + model1_tok, + model2_resp, + model2_tok ) # Fallback for unknown status - return f"Job status: {job['status']}", "", "", "", job_query + return status_report, "", "", "", job_query, model1_resp, model1_tok, model2_resp, model2_tok + +def generate_detailed_job_status(job_id, job): + """Generate detailed status report for a job showing model processing information""" + if not job: + return "Job not found" + + job_type = job.get("type", "unknown") + status = job.get("status", "unknown") + query = job.get("query", "") + model = job.get("model", "") + start_time = job.get("start_time", 0) + end_time = job.get("end_time", 0) + + report = f"## Job Status Report\n\n" + report += f"**Job ID:** {job_id}\n" + report += f"**Type:** {job_type}\n" + report += f"**Status:** {status}\n" + report += f"**Query:** {query[:100]}{'...' 
if len(query) > 100 else ''}\n\n" + + if job_type == "query": + # Find all jobs with the same query to show parallel processing + related_jobs = [(jid, jinfo) for jid, jinfo in jobs.items() + if jinfo.get("query") == query and jinfo.get("type") == "query"] + + report += f"## Model Processing Status\n\n" + + for jid, jinfo in related_jobs: + jmodel = jinfo.get("model", "Unknown") + jstatus = jinfo.get("status", "unknown") + jstart = jinfo.get("start_time", 0) + jend = jinfo.get("end_time", 0) + + if jstatus == "processing": + elapsed = time.time() - jstart + report += f"**{jmodel}:** β³ Processing (elapsed: {elapsed:.1f}s)\n" + elif jstatus == "completed": + elapsed = jend - jstart + result = jinfo.get("result", ("", "", "", "")) + input_tokens = result[1] if len(result) > 1 else "N/A" + output_tokens = result[2] if len(result) > 2 else "N/A" + report += f"**{jmodel}:** β Completed (time: {elapsed:.1f}s, {input_tokens}, {output_tokens})\n" + else: + report += f"**{jmodel}:** β {jstatus}\n" + + # Add summary + completed_jobs = [j for j in related_jobs if j[1].get("status") == "completed"] + processing_jobs = [j for j in related_jobs if j[1].get("status") == "processing"] + + report += f"\n## Summary\n" + report += f"- **Total models:** {len(related_jobs)}\n" + report += f"- **Completed:** {len(completed_jobs)}\n" + report += f"- **Processing:** {len(processing_jobs)}\n" + + if completed_jobs: + total_time = sum(j[1].get("end_time", 0) - j[1].get("start_time", 0) for j in completed_jobs) + report += f"- **Total processing time:** {total_time:.1f}s\n" + + elif job_type == "load_files": + if status == "processing": + elapsed = time.time() - start_time + report += f"**File loading in progress** (elapsed: {elapsed:.1f}s)\n" + elif status == "completed": + elapsed = end_time - start_time + report += f"**File loading completed** (time: {elapsed:.1f}s)\n" + + return report + +def update_model_responses_from_jobs(): + """Update Model 1 and Model 2 response fields based on completed jobs""" + global last_job_id + + # Process any completed jobs in the queue + try: + while not results_queue.empty(): + completed_id, result = results_queue.get_nowait() + if completed_id in jobs: + jobs[completed_id]["status"] = "completed" + jobs[completed_id]["result"] = result + jobs[completed_id]["end_time"] = time.time() + debug_print(f"Job {completed_id} completed and stored in jobs dictionary") + except queue.Empty: + pass + + # Find completed query jobs and organize by model position + model1_jobs = [(job_id, job_info) for job_id, job_info in jobs.items() + if job_info.get("type") == "query" and job_info.get("status") == "completed" + and job_info.get("model_position") == "model1"] + model2_jobs = [(job_id, job_info) for job_id, job_info in jobs.items() + if job_info.get("type") == "query" and job_info.get("status") == "completed" + and job_info.get("model_position") == "model2"] + + # Sort by completion time (most recent first) + model1_jobs.sort(key=lambda x: x[1].get("end_time", 0), reverse=True) + model2_jobs.sort(key=lambda x: x[1].get("end_time", 0), reverse=True) + + model1_response = "No completed jobs found" + model1_tokens = "Input/Output tokens: 0/0" + model2_response = "No completed jobs found" + model2_tokens = "Input/Output tokens: 0/0" + + if model1_jobs: + # Get the most recent Model 1 job + job_id, job_info = model1_jobs[0] + result = job_info.get("result", ("", "", "", "")) + model_name = job_info.get("model", "Unknown Model") + response_text = result[0] if len(result) > 0 else "No response" + 
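+        # "result" is the tuple this job stored from submit_query_updated:
+        # (response, context, input_tokens, output_tokens)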
input_tokens = result[1] if len(result) > 1 else "Input tokens: 0" + output_tokens = result[2] if len(result) > 2 else "Output tokens: 0" + + model1_response = f"Model: {model_name}\n{input_tokens} | {output_tokens}\n\n{response_text}" + model1_tokens = f"{input_tokens} | {output_tokens}" + + if model2_jobs: + # Get the most recent Model 2 job + job_id, job_info = model2_jobs[0] + result = job_info.get("result", ("", "", "", "")) + model_name = job_info.get("model", "Unknown Model") + response_text = result[0] if len(result) > 0 else "No response" + input_tokens = result[1] if len(result) > 1 else "Input tokens: 0" + output_tokens = result[2] if len(result) > 2 else "Output tokens: 0" + + model2_response = f"Model: {model_name}\n{input_tokens} | {output_tokens}\n\n{response_text}" + model2_tokens = f"{input_tokens} | {output_tokens}" + + return model1_response, model1_tokens, model2_response, model2_tokens # Function to clean up old jobs def cleanup_old_jobs(): @@ -414,181 +592,581 @@ def load_txt_from_url(url: str) -> Document: else: raise Exception(f"Failed to load {url} with status {response.status_code}") -class RemoteLLM(LLM): +# --- Model List for Dropdowns --- +# Each entry: display, backend, provider +models = [ + # NEBIUS + {"display": "π¦ GPT OSS 120b (Nebius)", "backend": "openai/gpt-oss-120b", "provider": "nebius"}, + {"display": "π¦ GPT OSS 20b (Nebius)", "backend": "openai/gpt-oss-20b", "provider": "nebius"}, + {"display": "π¦ Google Gemma 3 27b-Instruct (Nebius)", "backend": "google/gemma-3-27b-it ", "provider": "nebius"}, + {"display": "π¦ Kimi K2", "backend": "moonshotai/Kimi-K2-Instruct", "provider": "nebius"}, + {"display": "π¦ DeepSeek-R1-0528 (Nebius)", "backend": "deepseek-ai/DeepSeek-R1-0528", "provider": "nebius"}, + {"display": "π¦ DeepSeek-V3-0324 (Nebius)", "backend": "deepseek-ai/DeepSeek-V3-0324", "provider": "nebius"}, + {"display": "π¦ DeepSeek-V3 (Nebius)", "backend": "deepseek-ai/DeepSeek-V3", "provider": "nebius"}, + {"display": "π¦ DeepSeek-R1-Distill-Llama-70B (Nebius)", "backend": "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "provider": "nebius"}, + {"display": "π¦ Meta-Llama-3.3-70B-Instruct (Nebius)", "backend": "meta-llama/Llama-3.3-70B-Instruct", "provider": "nebius"}, + {"display": "π¦ Meta-Llama-3.1-8B-Instruct (Nebius)", "backend": "meta-llama/Meta-Llama-3.1-8B-Instruct", "provider": "nebius"}, + {"display": "π¦ Meta-Llama-3.1-70B-Instruct (Nebius)", "backend": "meta-llama/Meta-Llama-3.1-70B-Instruct", "provider": "nebius"}, + {"display": "π¦ Meta-Llama-3.1-405B-Instruct (Nebius)", "backend": "meta-llama/Meta-Llama-3.1-405B-Instruct", "provider": "nebius"}, + {"display": "π¦ NVIDIA Llama-3_1-Nemotron-Ultra-253B-v1 (Nebius)", "backend": "nvidia/Llama-3_1-Nemotron-Ultra-253B-v1", "provider": "nebius"}, + {"display": "π¦ NVIDIA Llama-3_3-Nemotron-Super-49B-v1 (Nebius)", "backend": "nvidia/Llama-3_3-Nemotron-Super-49B-v1", "provider": "nebius"}, + {"display": "π¦ Mistral-Nemo-Instruct-2407 (Nebius)", "backend": "mistralai/Mistral-Nemo-Instruct-2407", "provider": "nebius"}, + {"display": "π¦ Hermes 4 405B (Nebius)", "backend": "NousResearch/Hermes-4-405B", "provider": "nebius"}, + {"display": "π¦ Hermes 4 70B (Nebius)", "backend": "NousResearch/Hermes-4-70B", "provider": "nebius"}, + {"display": "π¦ GLM-4.5 (Nebius)", "backend": "zai-org/GLM-4.5", "provider": "nebius"}, + {"display": "π¦ GLM-4.5 AIR (Nebius)", "backend": "zai-org/GLM-4.5-Air", "provider": "nebius"}, + {"display": "π¦ Qwen3-235B-A22B (Nebius)", "backend": "Qwen/Qwen3-235B-A22B", 
"provider": "nebius"}, + {"display": "π¦ Qwen3-30B-A3B (Nebius)", "backend": "Qwen/Qwen3-30B-A3B", "provider": "nebius"}, + {"display": "π¦ Qwen3-32B (Nebius)", "backend": "Qwen/Qwen3-32B", "provider": "nebius"}, + {"display": "π¦ Qwen3-14B (Nebius)", "backend": "Qwen/Qwen3-14B", "provider": "nebius"}, + {"display": "π¦ Qwen3-4B-fast (Nebius)", "backend": "Qwen/Qwen3-4B-fast", "provider": "nebius"}, + {"display": "π¦ QwQ-32B (Nebius)", "backend": "Qwen/QwQ-32B", "provider": "nebius"}, + {"display": "π¦ Google Gemma-2-2b-it (Nebius)", "backend": "google/gemma-2-2b-it", "provider": "nebius"}, + {"display": "π¦ Google Gemma-2-9b-it (Nebius)", "backend": "google/gemma-2-9b-it", "provider": "nebius"}, + {"display": "π¦ Hermes-3-Llama-405B (Nebius)", "backend": "NousResearch/Hermes-3-Llama-405B", "provider": "nebius"}, + {"display": "π¦ Llama3-OpenBioLLM-70B (Nebius, Medical)", "backend": "aaditya/Llama3-OpenBioLLM-70B", "provider": "nebius"}, + {"display": "π¦ Qwen2.5-72B-Instruct (Nebius, Code)", "backend": "Qwen/Qwen2.5-72B-Instruct", "provider": "nebius"}, + {"display": "π¦ Qwen2.5-Coder-7B (Nebius, Code)", "backend": "Qwen/Qwen2.5-Coder-7B", "provider": "nebius"}, + {"display": "π¦ Qwen2.5-Coder-32B-Instruct (Nebius, Code)", "backend": "Qwen/Qwen2.5-Coder-32B-Instruct", "provider": "nebius"}, + # HuggingFace + {"display": "π€ Remote Meta-Llama-3 (HuggingFace)", "backend": "meta-llama/Meta-Llama-3-8B-Instruct", "provider": "hf_inference"}, + {"display": "π€ SciFive PubMed Classifier", "backend": "razent/SciFive-base-Pubmed_PMC", "provider": "hf_inference"}, + {"display": "π€ Tiny GPT-2 Classifier", "backend": "ydshieh/tiny-random-GPT2ForSequenceClassification", "provider": "hf_inference"}, + {"display": "π€ ArabianGPT QA (0.4B)", "backend": "gp-tar4/QA_FineTuned_ArabianGPT-03B", "provider": "hf_inference"}, + {"display": "π€ Tiny Mistral Classifier", "backend": "xshubhamx/tiny-mistral", "provider": "hf_inference"}, + {"display": "π€ Hallucination Scorer", "backend": "tcapelle/hallu_scorer", "provider": "hf_inference"}, + {"display": "πͺπΊ Mistral-API (Mistral)", "backend": "mistral-small-latest", "provider": "mistral"}, + # OpenAI + {"display": "πΊπΈ GPT-3.5 (OpenAI)", "backend": "gpt-3.5-turbo", "provider": "openai"}, + {"display": "πΊπΈ GPT-4o (OpenAI)", "backend": "gpt-4o", "provider": "openai"}, + {"display": "πΊπΈ GPT-4o mini (OpenAI)", "backend": "gpt-4o-mini", "provider": "openai"}, + {"display": "πΊπΈ o1-mini (OpenAI)", "backend": "o1-mini", "provider": "openai"}, + {"display": "πΊπΈ o3-mini (OpenAI)", "backend": "o3-mini", "provider": "openai"}, + # Grok (xAI) + {"display": "π¦Ύ Grok 2 (xAI)", "backend": "grok-2", "provider": "grok"}, + {"display": "π¦Ύ Grok 3 (xAI)", "backend": "grok-3", "provider": "grok"}, + # Anthropic + {"display": "π§ Sonnet 4 (Anthropic)", "backend": "sonnet-4", "provider": "anthropic"}, + {"display": "π§ Sonnet 3.7 (Anthropic)", "backend": "sonnet-3.7", "provider": "anthropic"}, + # Gemini (Google) + {"display": "π· Gemini 2.5 Pro (Google)", "backend": "gemini-2.5-pro", "provider": "gemini"}, + {"display": "π· Gemini 2.5 Flash (Google)", "backend": "gemini-2.5-flash", "provider": "gemini"}, + {"display": "π· Gemini 2.5 Flash Lite Preview (Google)", "backend": "gemini-2.5-flash-lite-preview-06-17", "provider": "gemini"}, + {"display": "π· Gemini 2.0 Flash (Google)", "backend": "gemini-2.0-flash", "provider": "gemini"}, + {"display": "π· Gemini 2.0 Flash Preview Image Gen (Text+Image) (Google)", "backend": "gemini-2.0-flash-preview-image-generation", 
"provider": "gemini"}, + {"display": "π· Gemini 2.0 Flash Lite (Google)", "backend": "gemini-2.0-flash-lite", "provider": "gemini"}, +] + + +model_display_options = [m["display"] for m in models] + +# --- ErrorLLM and LocalLLM must be defined first --- +class ErrorLLM(LLM): + @property + def _llm_type(self) -> str: + return "error_llm" + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + return "Error: LLM pipeline could not be created. Please check your configuration and try again." + @property + def _identifying_params(self) -> dict: + return {} + +class LocalLLM(LLM): + @property + def _llm_type(self) -> str: + return "local_llm" + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + return "Local LLM Fallback Response" + @property + def _identifying_params(self) -> dict: + return {} + +# --- NEBIUS LLM Class --- +class NebiusLLM(LLM): temperature: float = 0.5 top_p: float = 0.95 + top_k: int = 50 + max_tokens: int = 3000 + model: str = "meta-llama/Meta-Llama-3.1-70B-Instruct" - def __init__(self, temperature: float = 0.5, top_p: float = 0.95): - super().__init__() + def __init__(self, model: str, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50, max_tokens: int = 3000, **kwargs: Any): + try: + from openai import OpenAI + except ImportError: + raise ImportError("openai package is required for NEBIUS models.") + super().__init__(**kwargs) + api_key = NEBIUS_API_KEY or os.environ.get("NEBIUS_API_KEY") + if not api_key: + raise ValueError("Please set the NEBIUS_API_KEY either in the code or as an environment variable.") + self.model = model self.temperature = temperature self.top_p = top_p + self.top_k = top_k + self.max_tokens = max_tokens + # Use object.__setattr__ to bypass Pydantic field validation + object.__setattr__(self, "_client", OpenAI(base_url="https://api.studio.nebius.com/v1/", api_key=api_key)) + + @property + def _llm_type(self) -> str: + return "nebius_llm" + + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + try: + completion = self._client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=self.temperature, + top_p=self.top_p, + max_tokens=self.max_tokens + ) + return completion.choices[0].message.content if hasattr(completion.choices[0].message, 'content') else str(completion.choices[0].message) + except Exception as e: + return f"Error from NEBIUS: {str(e)}" + + @property + def _identifying_params(self) -> dict: + return {"model": self.model, "temperature": self.temperature, "top_p": self.top_p} + +# --- OpenAI LLM Class --- +class OpenAILLM(LLM): + temperature: float = 0.7 + top_p: float = 0.95 + top_k: int = 50 + max_tokens: int = 3000 + model: str = "gpt-3.5-turbo" + + def __init__(self, model: str, temperature: float = 0.7, top_p: float = 0.95, top_k: int = 50, max_tokens: int = 3000, **kwargs: Any): + import openai + super().__init__(**kwargs) + self.model = model + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.max_tokens = max_tokens + api_key = OPENAI_API_KEY or os.environ.get("OPENAI_API_KEY") + if not api_key: + raise ValueError("Please set the OPENAI_API_KEY either in the code or as an environment variable.") + openai.api_key = api_key + object.__setattr__(self, "_client", openai) @property def _llm_type(self) -> str: - return "remote_llm" + return "openai_llm" def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: try: - response = requests.post( - 
"http://localhost:11434/api/generate", - json={ - "model": "llama2", - "prompt": prompt, - "temperature": self.temperature, - "top_p": self.top_p - }, - stream=False - ) - if response.status_code == 200: - return response.json()["response"] + # Models with special parameter requirements + models_with_max_completion_tokens = ["o1-mini", "o3-mini", "gpt-4o", "gpt-4o-mini"] + o1o3_models = ["o1-mini", "o3-mini"] + + model_param = {} + if any(m in self.model for m in models_with_max_completion_tokens): + model_param["max_completion_tokens"] = self.max_tokens else: - return f"Error: {response.status_code}" + model_param["max_tokens"] = self.max_tokens + + kwargs = { + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + **model_param + } + if any(m in self.model for m in o1o3_models): + kwargs["temperature"] = 1 + kwargs["top_p"] = 1 + else: + kwargs["temperature"] = self.temperature + kwargs["top_p"] = self.top_p + + completion = self._client.chat.completions.create(**kwargs) + return completion.choices[0].message.content if hasattr(completion.choices[0].message, 'content') else str(completion.choices[0].message) except Exception as e: - return f"Error: {str(e)}" + return f"Error from OpenAI: {str(e)}" @property def _identifying_params(self) -> dict: - return { - "temperature": self.temperature, - "top_p": self.top_p - } + return {"model": self.model, "temperature": self.temperature, "top_p": self.top_p} +# --- HuggingFace LLM Classes --- +class HuggingFaceLLM(LLM): + temperature: float = 0.5 + top_p: float = 0.95 + top_k: int = 50 + max_tokens: int = 3000 + model: str = "meta-llama/Meta-Llama-3-8B-Instruct" + + def __init__(self, model: str, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50, max_tokens: int = 3000, **kwargs: Any): + from huggingface_hub import InferenceClient + super().__init__(**kwargs) + self.model = model + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.max_tokens = max_tokens + hf_api_token = HF_API_TOKEN or os.environ.get("HF_API_TOKEN") + if not hf_api_token: + raise ValueError("Please set the HF_API_TOKEN either in the code or as an environment variable to use HuggingFace inference.") + # Use object.__setattr__ to bypass Pydantic field validation + object.__setattr__(self, "_client", InferenceClient(token=hf_api_token, timeout=120)) + + @property + def _llm_type(self) -> str: + return "hf_llm" + + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + try: + response = self._client.text_generation( + prompt, + model=self.model, + temperature=self.temperature, + top_p=self.top_p, + max_new_tokens=self.max_tokens + ) + return response + except Exception as e: + return f"Error from HuggingFace: {str(e)}" + + @property + def _identifying_params(self) -> dict: + return {"model": self.model, "temperature": self.temperature, "top_p": self.top_p} + +# --- Mistral LLM Class --- class MistralLLM(LLM): temperature: float = 0.7 top_p: float = 0.95 - _client: Any = PrivateAttr(default=None) + top_k: int = 50 + max_tokens: int = 3000 + model: str = "mistral-small-latest" + client: Any = None # Changed from _client PrivateAttr to avoid Pydantic issues - def __init__(self, api_key: str, temperature: float = 0.7, top_p: float = 0.95, **kwargs: Any): + def __init__(self, model: str, temperature: float = 0.7, top_p: float = 0.95, top_k: int = 50, max_tokens: int = 3000, **kwargs: Any): try: - super().__init__(**kwargs) - object.__setattr__(self, '_client', Mistral(api_key=api_key)) - self.temperature = 
temperature - self.top_p = top_p + from mistralai import Mistral + except ImportError as e: + raise ImportError(f"mistralai package is required for Mistral models. Please install with: pip install mistralai. Error: {e}") except Exception as e: - debug_print(f"Init Mistral failed with error: {e}") - + raise ImportError(f"Unexpected error importing mistralai: {e}") + super().__init__(**kwargs) + + # Check for API key + api_key = MISTRAL_API_KEY or os.environ.get("MISTRAL_API_KEY") + if not api_key: + debug_print("MISTRAL_API_KEY not found in code or environment variables") + raise ValueError("Please set the MISTRAL_API_KEY either in the code or as an environment variable.") + + debug_print(f"Initializing MistralLLM with model: {model}, API key: {api_key[:8]}...") + + self.model = model + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.max_tokens = max_tokens + + try: + # Initialize the client as a regular attribute instead of PrivateAttr + self.client = Mistral(api_key=api_key) + debug_print("Mistral client created successfully") + except Exception as e: + debug_print(f"Error creating Mistral client: {str(e)}") + raise RuntimeError(f"Failed to create Mistral client: {str(e)}") + @property def _llm_type(self) -> str: return "mistral_llm" def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: try: - debug_print("Calling Mistral API...") - response = self._client.chat.complete( - model="mistral-small-latest", + debug_print(f"Mistral API call: model={self.model}, temp={self.temperature}, top_p={self.top_p}, top_k={self.top_k}, max_tokens={self.max_tokens}") + response = self.client.chat.complete( # Use self.client instead of self._client + model=self.model, messages=[{"role": "user", "content": prompt}], temperature=self.temperature, - top_p=self.top_p + top_p=self.top_p, + max_tokens=self.max_tokens ) + debug_print(f"Mistral API response received successfully") return response.choices[0].message.content except Exception as e: debug_print(f"Mistral API error: {str(e)}") - return f"Error generating response: {str(e)}" + return f"Error from Mistral: {str(e)}" @property def _identifying_params(self) -> dict: - return {"model": "mistral-small-latest"} + return {"model": self.model, "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, "max_tokens": self.max_tokens} + +# --- Gemini LLM Class --- +class GeminiLLM(LLM): + temperature: float = 0.7 + top_p: float = 0.95 + max_tokens: int = 3000 + model: str = "gemini-2.5-flash" + + # RPM and RPD (requests per minute and per day) limits for Gemini models + GEMINI_LIMITS = { + "gemini-2.5-pro": {"rpm": 5, "rpd": 100}, + "gemini-2.5-flash": {"rpm": 10, "rpd": 250}, + "gemini-2.5-flash-lite-preview-06-17": {"rpm": 15, "rpd": 1000}, + "gemini-2.0-flash": {"rpm": 15, "rpd": 200}, + "gemini-2.0-flash-preview-image-generation": {"rpm": 15, "rpd": 200}, + "gemini-2.0-flash-lite": {"rpm": 30, "rpd": 200}, + } + + def __init__(self, model: str, temperature: float = 0.7, top_p: float = 0.95, max_tokens: int = 3000, **kwargs: Any): + try: + import google.generativeai as genai + except ImportError: + raise ImportError("google-generativeai package is required for Gemini models.") + super().__init__(**kwargs) + api_key = GEMINI_API_KEY or os.environ.get("GEMINI_API_KEY") + if not api_key: + raise ValueError("Please set the GEMINI_API_KEY either in the code or as an environment variable.") + self.model = model # Use backend string directly + self.temperature = temperature + self.top_p = top_p + 
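+        # The per-model RPM/RPD limits from GEMINI_LIMITS are cached on the instance
+        # below via object.__setattr__ (they are not declared as Pydantic fields) and
+        # are enforced in _call().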
self.max_tokens = max_tokens + genai.configure(api_key=api_key) + object.__setattr__(self, "_client", genai) + object.__setattr__(self, "_rpm_limit", self.GEMINI_LIMITS.get(model, {}).get("rpm", None)) + object.__setattr__(self, "_rpd_limit", self.GEMINI_LIMITS.get(model, {}).get("rpd", None)) + object.__setattr__(self, "_last_request_time", 0) -class LocalLLM(LLM): @property def _llm_type(self) -> str: - return "local_llm" + return "gemini_llm" def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - truncated_prompt = truncate_prompt(prompt) - return f"Local LLM Fallback Response for: {truncated_prompt[:100]}..." + import time + import re + global GEMINI_LAST_REQUEST_TIME, GEMINI_DAILY_REQUESTS + model = self._client.GenerativeModel(self.model) + retries = 5 + # Daily RPD enforcement + today_str = datetime.datetime.now().strftime('%Y-%m-%d') + rpd_limit = object.__getattribute__(self, "_rpd_limit") if hasattr(self, "_rpd_limit") else None + count_info = GEMINI_DAILY_REQUESTS.get(self.model, (today_str, 0)) + if count_info[0] != today_str: + # New day, reset count + GEMINI_DAILY_REQUESTS[self.model] = (today_str, 0) + count_info = (today_str, 0) + if rpd_limit is not None and count_info[1] >= rpd_limit: + debug_print(f"Gemini: DAILY LIMIT REACHED for {self.model}: {count_info[1]}/{rpd_limit}") + return f"Error from Gemini: Daily request limit reached for {self.model} ({rpd_limit} per day)" + for attempt in range(retries): + # Strict RPM enforcement: global per-model + rpm_limit = object.__getattribute__(self, "_rpm_limit") if hasattr(self, "_rpm_limit") else None + if rpm_limit: + now = time.time() + min_interval = 60.0 / rpm_limit + last_time = GEMINI_LAST_REQUEST_TIME.get(self.model, 0) + elapsed = now - last_time + if elapsed < min_interval: + sleep_time = min_interval - elapsed + debug_print(f"Gemini: Sleeping {sleep_time:.2f}s to respect RPM limit for {self.model}") + time.sleep(sleep_time) + try: + response = model.generate_content(prompt, generation_config={ + "temperature": self.temperature, + "top_p": self.top_p, + "max_output_tokens": self.max_tokens + }) + now = time.time() + GEMINI_LAST_REQUEST_TIME[self.model] = now + object.__setattr__(self, "_last_request_time", now) + # Increment daily request count + count_info = GEMINI_DAILY_REQUESTS.get(self.model, (today_str, 0)) + GEMINI_DAILY_REQUESTS[self.model] = (today_str, count_info[1] + 1) + rpd_limit = object.__getattribute__(self, "_rpd_limit") if hasattr(self, "_rpd_limit") else None + debug_print(f"Gemini: {self.model} daily usage: {GEMINI_DAILY_REQUESTS[self.model][1]}/{rpd_limit}") + return response.text if hasattr(response, 'text') else str(response) + except Exception as e: + msg = str(e) + debug_print(f"Gemini error: {msg}") + # Check for any 429 error and always extract retry_delay + if "429" in msg: + retry_delay = None + match = re.search(r'retry_delay\s*{\s*seconds:\s*(\d+)', msg) + if match: + retry_delay = int(match.group(1)) + sleep_time = retry_delay + 2 + debug_print(f"Gemini: 429 received, sleeping for retry_delay {retry_delay}s + 2s buffer (total {sleep_time}s)") + time.sleep(sleep_time) + continue + # If retry_delay is present but empty, sleep for 3 seconds and retry + elif 'retry_delay' in msg: + debug_print(f"Gemini: 429 received, empty retry_delay, sleeping for 3s and retrying") + time.sleep(3) + continue + else: + debug_print(f"Gemini: 429 received, but no retry_delay found. 
Returning error.") + return f"Error from Gemini: {msg}" + # For all other errors, do not retry + return f"Error from Gemini: {msg}" @property def _identifying_params(self) -> dict: - return {} + return {"model": self.model, "temperature": self.temperature, "top_p": self.top_p} + +# --- Grok LLM Class --- +class GrokLLM(LLM): + temperature: float = 0.7 + top_p: float = 0.95 + max_tokens: int = 3000 + model: str = "grok-2" + + def __init__(self, model: str, temperature: float = 0.7, top_p: float = 0.95, max_tokens: int = 3000, **kwargs: Any): + import requests + super().__init__(**kwargs) + api_key = GROK_API_KEY or os.environ.get("GROK_API_KEY") + if not api_key: + raise ValueError("Please set the GROK_API_KEY either in the code or as an environment variable.") + self.model = model + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + object.__setattr__(self, "_api_key", api_key) -class ErrorLLM(LLM): @property def _llm_type(self) -> str: - return "error_llm" - + return "grok_llm" + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - return "Error: LLM pipeline could not be created. Please check your configuration and try again." - + import requests + headers = { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json" + } + data = { + "model": self.model, + "messages": [{"role": "user", "content": prompt}], + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens + } + try: + response = requests.post("https://api.x.ai/v1/chat/completions", headers=headers, json=data, timeout=60) + response.raise_for_status() + result = response.json() + return result["choices"][0]["message"]["content"] + except Exception as e: + return f"Error from Grok: {str(e)}" + @property def _identifying_params(self) -> dict: - return {} + return {"model": self.model, "temperature": self.temperature, "top_p": self.top_p} + +# --- Anthropic LLM Class --- +class AnthropicLLM(LLM): + temperature: float = 0.7 + top_p: float = 0.95 + max_tokens: int = 3000 + model: str = "claude-sonnet-4-20250514" + + def __init__(self, model: str, temperature: float = 0.7, top_p: float = 0.95, max_tokens: int = 3000, **kwargs: Any): + try: + import anthropic + except ImportError: + raise ImportError("anthropic package is required for Anthropic models.") + + super().__init__(**kwargs) + + api_key = ANTHROPIC_API_KEY or os.environ.get("ANTHROPIC_API_KEY") + if not api_key: + raise ValueError("Please set the ANTHROPIC_API_KEY either in the code or as an environment variable.") + + # Map display/backend names to supported API model names + model_map = { + "sonnet-4": "claude-sonnet-4-20250514", + "sonnet-3.7": "claude-3-7-sonnet-20250219", + } + self.model = model_map.get(model, model) + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + + # Correct initialization - use anthropic.Anthropic(), not anthropic.Client() + object.__setattr__(self, "_client", anthropic.Anthropic(api_key=api_key)) + @property + def _llm_type(self) -> str: + return "anthropic_llm" + + def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: + try: + response = self._client.messages.create( + model=self.model, + max_tokens=self.max_tokens, + messages=[{"role": "user", "content": prompt}], + temperature=self.temperature, + top_p=self.top_p + ) + + # Extract text content from the response + if hasattr(response, 'content') and response.content: + if isinstance(response.content, list): + # Handle list of content 
blocks + text_content = "" + for content_block in response.content: + if hasattr(content_block, 'text'): + text_content += content_block.text + elif isinstance(content_block, dict) and 'text' in content_block: + text_content += content_block['text'] + return text_content + else: + return str(response.content) + + return str(response) + + except Exception as e: + return f"Error from Anthropic: {str(e)}" + + @property + def _identifying_params(self) -> dict: + return {"model": self.model, "temperature": self.temperature, "top_p": self.top_p} + +# --- Update SimpleLLMChain to support all providers --- class SimpleLLMChain: - def __init__(self, llm_choice: str = "Meta-Llama-3", - temperature: float = 0.5, - top_p: float = 0.95) -> None: + def __init__(self, llm_choice: str = model_display_options[0], temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50, max_tokens: int = 3000): self.llm_choice = llm_choice self.temperature = temperature self.top_p = top_p + self.top_k = top_k + self.max_tokens = max_tokens self.llm = self.create_llm_pipeline() - self.conversation_history = [] # Keep track of conversation - + self.conversation_history = [] + def create_llm_pipeline(self): - from langchain.llms.base import LLM # Import LLM here so it's always defined - normalized = self.llm_choice.lower() + # Find the model entry + model_entry = next((m for m in models if m["display"] == self.llm_choice), None) + if not model_entry: + return ErrorLLM() + provider = model_entry["provider"] + backend = model_entry["backend"] try: - if "remote" in normalized: - debug_print("Creating remote Meta-Llama-3 pipeline via Hugging Face Inference API...") - from huggingface_hub import InferenceClient - repo_id = "meta-llama/Meta-Llama-3-8B-Instruct" - hf_api_token = os.environ.get("HF_API_TOKEN") - if not hf_api_token: - raise ValueError("Please set the HF_API_TOKEN environment variable to use remote inference.") - - client = InferenceClient(token=hf_api_token, timeout=120) - - # We no longer use wait_for_model because it's unsupported - def remote_generate(prompt: str) -> str: - max_retries = 3 - backoff = 2 # start with 2 seconds - for attempt in range(max_retries): - try: - debug_print(f"Remote generation attempt {attempt+1}") - response = client.text_generation( - prompt, - model=repo_id, - temperature=self.temperature, - top_p=self.top_p, - max_new_tokens=512 # Reduced token count for speed - ) - return response - except Exception as e: - debug_print(f"Attempt {attempt+1} failed with error: {e}") - if attempt == max_retries - 1: - raise - time.sleep(backoff) - backoff *= 2 # exponential backoff - return "Failed to generate response after multiple attempts." 
- - class RemoteLLM(LLM): - @property - def _llm_type(self) -> str: - return "remote_llm" - - def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str: - return remote_generate(prompt) - - @property - def _identifying_params(self) -> dict: - return {"model": repo_id} - - debug_print("Remote Meta-Llama-3 pipeline created successfully.") - return RemoteLLM() - - elif "mistral" in normalized: - api_key = os.getenv("MISTRAL_API_KEY") - return MistralLLM(api_key=api_key, temperature=self.temperature, top_p=self.top_p) + if provider == "nebius": + return NebiusLLM(model=backend, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, max_tokens=self.max_tokens) + elif provider == "openai": + return OpenAILLM(model=backend, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, max_tokens=self.max_tokens) + elif provider == "hf_inference": + return HuggingFaceLLM(model=backend, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, max_tokens=self.max_tokens) + elif provider == "mistral": + return MistralLLM(model=backend, temperature=self.temperature, top_p=self.top_p, top_k=self.top_k, max_tokens=self.max_tokens) + elif provider == "gemini": + return GeminiLLM(model=backend, temperature=self.temperature, top_p=self.top_p, max_tokens=self.max_tokens) + elif provider == "grok": + return GrokLLM(model=backend, temperature=self.temperature, top_p=self.top_p, max_tokens=self.max_tokens) + elif provider == "anthropic": + return AnthropicLLM(model=backend, temperature=self.temperature, top_p=self.top_p, max_tokens=self.max_tokens) else: return LocalLLM() except Exception as e: debug_print(f"Error creating LLM pipeline: {str(e)}") return ErrorLLM() - def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float): + def update_llm_pipeline(self, new_model_choice: str, temperature: float, top_p: float, top_k: int, max_tokens: int): self.llm_choice = new_model_choice self.temperature = temperature self.top_p = top_p + self.top_k = top_k + self.max_tokens = max_tokens self.llm = self.create_llm_pipeline() def submit_query(self, query: str) -> tuple: @@ -603,7 +1181,7 @@ class SimpleLLMChain: return (f"Error processing query: {str(e)}", "Input tokens: 0", "Output tokens: 0") # Update submit_query_updated to work with the simplified chain -def submit_query_updated(query: str, model_choice: str = None, temperature: float = 0.5, top_p: float = 0.95): +def submit_query_updated(query: str, model_choice: str = None, temperature: float = 0.5, top_p: float = 0.95, top_k: int = 50, max_tokens: int = 3000): """Process a query with the specified model and parameters.""" debug_print(f"Processing query: {query}") if not query: @@ -616,10 +1194,12 @@ def submit_query_updated(query: str, model_choice: str = None, temperature: floa llm_chain = SimpleLLMChain( llm_choice=model_choice, temperature=temperature, - top_p=top_p + top_p=top_p, + top_k=top_k, + max_tokens=max_tokens ) elif llm_chain.llm_choice != model_choice: - llm_chain.update_llm_pipeline(model_choice, temperature, top_p) + llm_chain.update_llm_pipeline(model_choice, temperature, top_p, top_k, max_tokens) response, input_tokens, output_tokens = llm_chain.submit_query(query) return response, "", input_tokens, output_tokens @@ -704,6 +1284,535 @@ def reset_app_updated(): "Model used: Not selected" ) +# Batch query function + +error_patterns = [ + r"error generating response:", + r"api error occurred:", + r"bad gateway", + r"cloudflare", + r"server disconnected without sending a response", + 
r"getaddrinfo failed" +] + +# Batch query function + +def run_batch_query(query, model1, temperature, top_p, top_k, max_tokens, num_runs, delay_ms, prefix=None): + import re + num_runs = int(num_runs) + delay_ms = int(delay_ms) + results = [] + error_count = 0 + token_counts = [] + outputs = [] + model_name = model1 + # Sanitize prefix and model name for filenames + def sanitize(s): + return re.sub(r'[^A-Za-z0-9_-]+', '', str(s).replace(' ', '_')) + safe_prefix = sanitize(prefix) if prefix else '' + safe_model = sanitize(model_name) + date_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + for i in range(num_runs): + attempt = 0 + max_attempts = 5 + while attempt < max_attempts: + response, _, input_tokens, output_tokens = submit_query_updated(query, model1, temperature, top_p, top_k, max_tokens) + output = response if isinstance(response, str) else str(response) + if any(re.search(pat, output, re.IGNORECASE) for pat in error_patterns): + error_count += 1 + attempt += 1 + time.sleep((delay_ms/1000.0) * (attempt+1)) + continue + else: + break + try: + token_num = 0 + if output_tokens is not None: + try: + last_token = output_tokens.split()[-1] if isinstance(output_tokens, str) else str(output_tokens) + if last_token.isdigit(): + token_num = int(last_token) + except Exception as e: + debug_print(f"Token count conversion failed for output_tokens={output_tokens}: {e}") + else: + token_num = 0 + except Exception as e: + debug_print(f"Token count conversion outer exception for output_tokens={output_tokens}: {e}") + token_num = 0 + token_counts.append(token_num) + results.append({ + 'run': i+1, + 'output': output, + 'input_tokens': input_tokens, + 'output_tokens': output_tokens, + 'tokens': token_num, + 'error': attempt if attempt > 0 else 0 + }) + outputs.append(f"=== Query {i+1}/{num_runs} ===\nTokens: {token_num}\n{output}") + time.sleep(delay_ms/1000.0) + # Save to CSV + filename = f"{safe_prefix + '-' if safe_prefix else ''}{num_runs}_{safe_model}_{date_str}.csv" + abs_csv_path = os.path.abspath(filename) + with open(abs_csv_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Run', 'Output', 'Input Tokens', 'Output Tokens', 'Tokens', 'Error Retries']) + for r in results: + writer.writerow([r['run'], r['output'], r['input_tokens'], r['output_tokens'], r['tokens'], r['error']]) + # Stats + total_tokens = sum(token_counts) + avg_tokens = statistics.mean(token_counts) if token_counts else 0 + stdev_tokens = statistics.stdev(token_counts) if len(token_counts) > 1 else 0 + stats = f"Total queries: {num_runs}\nTotal tokens: {total_tokens}\nAverage tokens: {avg_tokens:.2f}\nSTDEV tokens: {stdev_tokens:.2f}\nErrors encountered: {error_count}" + output_text = f"Model: {model_name}\n\n" + '\n\n'.join(outputs) + return output_text, abs_csv_path, stats + +# Async batch job submission + +def submit_batch_query_async(prefix, query, prompt_mode, model, temperature, top_p, top_k, max_tokens, num_runs, delay_ms): + global last_job_id + if not query: + return ("Please enter a non-empty query", "", "", get_job_list()) + job_id = str(uuid.uuid4()) + debug_print(f"Starting async batch job {job_id} for batch query") + threading.Thread( + target=process_in_background, + args=(job_id, process_batch_query_job, [job_id, prefix, query, "All at Once", model, temperature, top_p, top_k, max_tokens, num_runs, delay_ms]) + ).start() + jobs[job_id] = { + "status": "processing", + "type": "batch_query", + "start_time": time.time(), + "query": query, + "model": model, 
+ "params": { + "prefix": prefix, + "prompt_mode": prompt_mode, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "max_tokens": max_tokens, + "num_runs": num_runs, + "delay_ms": delay_ms + } + } + last_job_id = job_id + return ( + f"Batch job submitted and processing in the background (Job ID: {job_id}).\n\nUse 'Check Job Status' tab with this ID to get results.", + job_id, + query, + get_job_list() + ) + +def process_batch_query_job(job_id, prefix, query, prompt_mode, model, temperature, top_p, top_k, max_tokens, num_runs, delay_ms): + import statistics + import os + num_runs = int(num_runs) + delay_ms = int(delay_ms) + results = [] + error_count = 0 + token_counts = [] + outputs = [] + model_name = model + query_times = [] + batch_start = time.time() + # Sanitize prefix and model name for filenames + def sanitize(s): + import re + return re.sub(r'[^A-Za-z0-9_-]+', '', str(s).replace(' ', '_')) + safe_prefix = sanitize(prefix) if prefix else '' + safe_model = sanitize(model_name) + date_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + if prompt_mode == "Sequential Prompts": + # Parse the query to extract instruction and individual prompts + lines = query.strip().split('\n') + if len(lines) < 2: + debug_print("Sequential mode requires at least 2 lines: instruction + at least one prompt") + return "Error: Sequential mode requires at least 2 lines (instruction + prompts)", "", "", "" + + instruction = lines[0].strip() + individual_prompts = [line.strip() for line in lines[1:] if line.strip()] + + debug_print(f"Sequential mode: instruction='{instruction}', {len(individual_prompts)} prompts") + + for i, prompt in enumerate(individual_prompts): + # For each prompt, run it multiple times based on num_runs + for run_num in range(num_runs): + # Combine instruction with individual prompt + full_prompt = f"{instruction}\n\n{prompt}" + + attempt = 0 + max_attempts = 5 + start = time.time() + while attempt < max_attempts: + response, _, input_tokens, output_tokens = submit_query_updated(full_prompt, model, temperature, top_p, top_k, max_tokens) + output = response if isinstance(response, str) else str(response) + if any(re.search(pat, output, re.IGNORECASE) for pat in error_patterns): + error_count += 1 + attempt += 1 + time.sleep((delay_ms/1000.0) * (attempt+1)) + continue + else: + break + end = time.time() + elapsed = end - start + query_times.append(elapsed) + + try: + token_num = 0 + if output_tokens is not None: + try: + last_token = output_tokens.split()[-1] if isinstance(output_tokens, str) else str(output_tokens) + if last_token.isdigit(): + token_num = int(last_token) + except Exception as e: + debug_print(f"Token count conversion failed for output_tokens={output_tokens}: {e}") + else: + token_num = 0 + except Exception as e: + debug_print(f"Token count conversion outer exception for output_tokens={output_tokens}: {e}") + token_num = 0 + + token_counts.append(token_num) + results.append({ + 'prompt_number': i+1, + 'run': run_num+1, + 'input_prompt': prompt, + 'full_prompt': full_prompt, + 'output': output, + 'input_tokens': input_tokens, + 'output_tokens': output_tokens, + 'tokens': token_num, + 'error': attempt if attempt > 0 else 0, + 'time': elapsed + }) + outputs.append(f"=== Prompt {i+1}/{len(individual_prompts)} - Run {run_num+1}/{num_runs} ===\nInput: {prompt}\nTokens: {token_num}\nOutput: {output}") + + # --- Update partial_results for live progress --- + total_processed = i * num_runs + run_num + 1 + total_to_process = len(individual_prompts) * num_runs + 
jobs[job_id]["partial_results"] = { + "num_done": total_processed, + "total": total_to_process, + "avg_time": statistics.mean(query_times) if query_times else 0, + "stdev_time": statistics.stdev(query_times) if len(query_times) > 1 else 0, + "total_tokens": sum(token_counts), + "avg_tokens": statistics.mean(token_counts) if token_counts else 0, + "stdev_tokens": statistics.stdev(token_counts) if len(token_counts) > 1 else 0, + "errors": error_count, + } + time.sleep(delay_ms/1000.0) + else: + # Original "All at Once" logic + for i in range(num_runs): + attempt = 0 + max_attempts = 5 + start = time.time() + while attempt < max_attempts: + response, _, input_tokens, output_tokens = submit_query_updated(query, model, temperature, top_p, top_k, max_tokens) + output = response if isinstance(response, str) else str(response) + if any(re.search(pat, output, re.IGNORECASE) for pat in error_patterns): + error_count += 1 + attempt += 1 + time.sleep((delay_ms/1000.0) * (attempt+1)) + continue + else: + break + end = time.time() + elapsed = end - start + query_times.append(elapsed) + try: + token_num = 0 + if output_tokens is not None: + try: + last_token = output_tokens.split()[-1] if isinstance(output_tokens, str) else str(output_tokens) + if last_token.isdigit(): + token_num = int(last_token) + except Exception as e: + debug_print(f"Token count conversion failed for output_tokens={output_tokens}: {e}") + else: + token_num = 0 + except Exception as e: + debug_print(f"Token count conversion outer exception for output_tokens={output_tokens}: {e}") + token_num = 0 + token_counts.append(token_num) + results.append({ + 'run': i+1, + 'output': output, + 'input_tokens': input_tokens, + 'output_tokens': output_tokens, + 'tokens': token_num, + 'error': attempt if attempt > 0 else 0, + 'time': elapsed + }) + outputs.append(f"=== Query {i+1}/{num_runs} ===\nTokens: {token_num}\n{output}") + # --- Update partial_results for live progress --- + jobs[job_id]["partial_results"] = { + "num_done": i+1, + "total": num_runs, + "avg_time": statistics.mean(query_times) if query_times else 0, + "stdev_time": statistics.stdev(query_times) if len(query_times) > 1 else 0, + "total_tokens": sum(token_counts), + "avg_tokens": statistics.mean(token_counts) if token_counts else 0, + "stdev_tokens": statistics.stdev(token_counts) if len(token_counts) > 1 else 0, + "errors": error_count, + } + time.sleep(delay_ms/1000.0) + batch_end = time.time() + total_time = batch_end - batch_start + avg_time = statistics.mean(query_times) if query_times else 0 + stdev_time = statistics.stdev(query_times) if len(query_times) > 1 else 0 + # Save to CSV + if prompt_mode == "Sequential Prompts": + filename = f"{safe_prefix + '-' if safe_prefix else ''}sequential-{safe_model}_{date_str}.csv" + abs_csv_path = os.path.abspath(filename) + with open(abs_csv_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = csv.writer(csvfile) + writer.writerow(['Prompt Number', 'Run', 'Input Prompt', 'Full Prompt', 'Output', 'Input Tokens', 'Output Tokens', 'Tokens', 'Error Retries', 'Time (s)']) + for r in results: + writer.writerow([ + r['prompt_number'], + r['run'], + r['input_prompt'], + r['full_prompt'], + r['output'], + r['input_tokens'], + r['output_tokens'], + r['tokens'], + r['error'], + f"{r['time']:.3f}" + ]) + else: + filename = f"{safe_prefix + '-' if safe_prefix else ''}{num_runs}-{safe_model}_{date_str}.csv" + abs_csv_path = os.path.abspath(filename) + with open(abs_csv_path, 'w', newline='', encoding='utf-8') as csvfile: + writer = 
csv.writer(csvfile) + writer.writerow(['Run', 'Output', 'Input Tokens', 'Output Tokens', 'Tokens', 'Error Retries', 'Time (s)']) + for r in results: + writer.writerow([r['run'], r['output'], r['input_tokens'], r['output_tokens'], r['tokens'], r['error'], f"{r['time']:.3f}"]) + # Save prompt TXT file + txt_filename = f"{safe_prefix}-{num_runs}-{1}_LLMs_prompt_{date_str}.TXT" + abs_txt_path = os.path.abspath(txt_filename) + with open(abs_txt_path, 'w', encoding='utf-8') as txtfile: + txtfile.write(query) + # Stats + total_tokens = sum(token_counts) + avg_tokens = statistics.mean(token_counts) if token_counts else 0 + stdev_tokens = statistics.stdev(token_counts) if len(token_counts) > 1 else 0 + + if prompt_mode == "Sequential Prompts": + total_prompts = len(individual_prompts) + total_runs = total_prompts * num_runs + stats = ( + f"Prompt mode: {prompt_mode}\n" + f"Total prompts: {total_prompts}\n" + f"Runs per prompt: {num_runs}\n" + f"Total runs: {total_runs}\n" + f"Total tokens: {total_tokens}\n" + f"Average tokens: {avg_tokens:.2f}\n" + f"STDEV tokens: {stdev_tokens:.2f}\n" + f"Errors encountered: {error_count}\n" + f"Total time elapsed: {total_time:.2f} s\n" + f"Average time per run: {avg_time:.2f} s\n" + f"STD time per run: {stdev_time:.2f} s" + ) + else: + stats = ( + f"Prompt mode: {prompt_mode}\n" + f"Total queries: {num_runs}\n" + f"Total tokens: {total_tokens}\n" + f"Average tokens: {avg_tokens:.2f}\n" + f"STDEV tokens: {stdev_tokens:.2f}\n" + f"Errors encountered: {error_count}\n" + f"Total time elapsed: {total_time:.2f} s\n" + f"Average time per query: {avg_time:.2f} s\n" + f"STD time per query: {stdev_time:.2f} s" + ) + + output_text = f"Model: {model_name}\n\n" + '\n\n'.join(outputs) + return output_text, abs_csv_path, stats, abs_txt_path + +def check_batch_job_status(job_id): + # Use same logic as check_job_status, but for batch jobs + try: + while not results_queue.empty(): + completed_id, result = results_queue.get_nowait() + if completed_id in jobs: + jobs[completed_id]["status"] = "completed" + jobs[completed_id]["result"] = result + jobs[completed_id]["end_time"] = time.time() + debug_print(f"Job {completed_id} completed and stored in jobs dictionary") + except queue.Empty: + pass + if job_id not in jobs: + # Always return 9 outputs + return ("Job not found. 
+
+def check_batch_job_status(job_id):
+    # Use same logic as check_job_status, but for batch jobs
+    try:
+        while not results_queue.empty():
+            completed_id, result = results_queue.get_nowait()
+            if completed_id in jobs:
+                jobs[completed_id]["status"] = "completed"
+                jobs[completed_id]["result"] = result
+                jobs[completed_id]["end_time"] = time.time()
+                debug_print(f"Job {completed_id} completed and stored in jobs dictionary")
+    except queue.Empty:
+        pass
+    if job_id not in jobs:
+        # Always return 9 outputs
+        return ("Job not found. Please check the ID and try again.", "", "", "", "", "", "", "", "")
+    job = jobs[job_id]
+    # If this is a ZIP job and all sub-jobs are completed, create the ZIP
+    if job.get("output_format") == "ZIP" and job.get("zip_job_ids"):
+        all_done = all(jobs[jid]["status"] == "completed" for jid in job["zip_job_ids"])
+        if all_done and not job.get("zip_created"):
+            # Collect all CSV paths and TXT prompt files
+            csv_paths = []
+            txt_paths = []
+            for jid in job["zip_job_ids"]:
+                result = jobs[jid]["result"]
+                if isinstance(result, (list, tuple)) and len(result) > 1:
+                    csv_paths.append(result[1])
+                if isinstance(result, (list, tuple)) and len(result) > 3:
+                    txt_paths.append(result[3])
+            # Create ZIP with new naming convention
+            prefix = job.get("params", {}).get("prefix", "batch")
+            num_runs = job.get("params", {}).get("num_runs", len(job["zip_job_ids"]))
+            num_llms = len(job["zip_job_ids"])
+            date_str = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
+            def sanitize(s):
+                import re
+                return re.sub(r'[^A-Za-z0-9_-]+', '', str(s).replace(' ', '_'))
+            safe_prefix = sanitize(prefix) if prefix else 'batch'
+            zip_name = f"{safe_prefix}-{num_runs}_{num_llms}_LLMs_{date_str}.zip"
+            with zipfile.ZipFile(zip_name, 'w') as zipf:
+                for csv_path in csv_paths:
+                    zipf.write(csv_path, os.path.basename(csv_path))
+                for txt_path in txt_paths:
+                    zipf.write(txt_path, os.path.basename(txt_path))
+            job["zip_created"] = True
+            job["zip_path"] = os.path.abspath(zip_name)
+        if job.get("zip_created"):
+            return (f"ZIP archive created: {os.path.basename(job['zip_path'])}", job["zip_path"], "ZIP archive ready.", job.get("query", ""), "", "", "", "", "")
+        else:
+            # Show progress info for ZIP jobs just like normal batch jobs
+            # Aggregate progress from all sub-jobs
+            num_total = len(job["zip_job_ids"])
+            num_done = sum(1 for jid in job["zip_job_ids"] if jobs[jid]["status"] == "completed")
+            # Optionally, aggregate stats
+            total_tokens = 0
+            errors = 0
+            for jid in job["zip_job_ids"]:
+                j = jobs[jid]
+                if j["status"] == "completed":
+                    result = j.get("result", ("", "", ""))
+                    stats = result[2] if len(result) > 2 else ""
+                    if stats:
+                        for line in stats.split('\n'):
+                            if line.lower().startswith("total tokens"):
+                                try:
+                                    total_tokens += int(line.split(":", 1)[1].strip())
+                                except Exception:
+                                    pass
+                            if line.lower().startswith("errors encountered"):
+                                try:
+                                    errors += int(line.split(":", 1)[1].strip())
+                                except Exception:
+                                    pass
+            temp_stats = f"Batch ZIP job is being processed.\nJobs completed: {num_done} out of {num_total}\nTotal tokens so far: {total_tokens}\nErrors encountered: {errors}\n\nZIP will be created when all jobs are done."
+            return (temp_stats, "", "", job.get("query", ""), "", "", "", "", "")
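+    # Neither a finished ZIP bundle nor a completed single job: report whatever
+    # progress the worker thread has published in job["partial_results"] so far.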
+ temp_stats += f"Progress: {num_done} out of {total} total runs\n" + temp_stats += f"({num_prompts} prompts Γ {num_runs} runs each)\n" + else: + temp_stats += f"Queries run: {num_done} out of {total}\n" + + avg_time = partial.get("avg_time", None) + stdev_time = partial.get("stdev_time", None) + total_tokens = partial.get("total_tokens", None) + avg_tokens = partial.get("avg_tokens", None) + stdev_tokens = partial.get("stdev_tokens", None) + errors = partial.get("errors", None) + if avg_time is not None and stdev_time is not None: + temp_stats += f"Average time per query: {avg_time}\nSTDEV time: {stdev_time}\n" + if total_tokens is not None: + temp_stats += f"Total tokens: {total_tokens}\n" + if avg_tokens is not None: + temp_stats += f"Average tokens: {avg_tokens}\n" + if stdev_tokens is not None: + temp_stats += f"STDEV tokens: {stdev_tokens}\n" + if errors is not None: + temp_stats += f"Errors encountered: {errors}\n" + else: + # If no partials, show total planned queries with better description + job_params = job.get("params", {}) + prompt_mode = job_params.get("prompt_mode", "All at Once") + num_runs = job_params.get("num_runs", "?") + + if prompt_mode == "Sequential Prompts": + # For sequential prompts, we need to know the number of prompts + # This will be available once processing starts + temp_stats += f"Starting sequential prompts processing...\n" + temp_stats += f"Will run {num_runs} times per prompt\n" + else: + temp_stats += f"Starting batch processing...\n" + temp_stats += f"Will run {num_runs} times\n" + temp_stats += "\nTry checking again in a few seconds." + return ( + temp_stats, + "", + "", + job.get("query", ""), + "", + "", + "", + "", + "" + ) + if job["status"] == "completed": + result = job["result"] + # Defensive unpack: only take first 3 elements if more are present + if isinstance(result, (list, tuple)): + output_text, abs_csv_path, stats, abs_txt_path = result[:4] if len(result) >= 4 else (result + ("",) * (4 - len(result))) + else: + output_text, abs_csv_path, stats, abs_txt_path = result, "", "", "" + # Parse stats for details + stats_dict = {} + stats_lines = stats.split('\n') if stats else [] + for line in stats_lines: + if ':' in line: + k, v = line.split(':', 1) + stats_dict[k.strip().lower()] = v.strip() + # Timing info + elapsed = job.get("end_time", 0) - job.get("start_time", 0) + # Try to extract number of queries run + total_queries = stats_dict.get("total queries", "?") + # Try to extract average and stdev time if present + avg_time = stats_dict.get("average time per query", None) + stdev_time = stats_dict.get("std time per query", None) + # Compose enhanced header + header = f"Elapsed time: {elapsed:.2f}s\n" + header += f"Queries run: {total_queries} out of {total_queries}\n" if total_queries != "?" 
else "" + if avg_time and stdev_time: + header += f"Average time per query: {avg_time}\nSTDEV time: {stdev_time}\n" + # Add token and error stats if present + for k in ["total tokens", "average tokens", "stdev tokens", "errors encountered"]: + if k in stats_dict: + header += f"{k.title()}: {stats_dict[k]}\n" + # Add a separator + header += "\n---\n" + # Show header + per-query outputs (restore output_text here) + return header + output_text, abs_csv_path, header + output_text, job.get("query", ""), "", "", "", "", "" + # Always return 9 outputs + return (f"Job status: {job['status']}", "", "", job.get("query", ""), "", "", "", "", "") + +# Gradio download helper + +def download_csv(csv_path): + with open(csv_path, 'rb') as f: + return f.read(), csv_path + # ---------------------------- # Gradio Interface Setup # ---------------------------- @@ -741,17 +1850,50 @@ def periodic_update(is_checked): if is_checked: global last_job_id job_list_md = refresh_job_list() - job_status = check_job_status(last_job_id) if last_job_id else ("No job ID available", "", "", "", "") + job_status = check_job_status(last_job_id) if last_job_id else ("No job ID available", "", "", "", "", "", "", "", "") query_results = run_query(10) # Use a fixed value or another logic if needed - return job_list_md, job_status[0], query_results, "" # Return empty string instead of context + # Also update model responses + model1_resp, model1_tok, model2_resp, model2_tok = update_model_responses_from_jobs() + return job_list_md, job_status[0], query_results, "", model1_resp, model1_tok, model2_resp, model2_tok, "", "", "" else: # Return empty values to stop updates - return "", "", [], "" + return "", "", [], "", "", "", "", "", "", "", "" # Define a function to determine the interval based on the checkbox state def get_interval(is_checked): return 2 if is_checked else None +# 1. Utility function to list all CSV files in the workspace +import glob + +def list_all_csv_files(): + csv_files = sorted(glob.glob("*.csv"), key=os.path.getmtime, reverse=True) + zip_files = sorted(glob.glob("*.zip"), key=os.path.getmtime, reverse=True) + all_files = csv_files + zip_files + if not all_files: + return "No CSV or ZIP files found.", [], [] + # Gather file info: name, date/time, size + file_infos = [] + for f in all_files: + stat = os.stat(f) + dt = datetime.datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S') + size_kb = stat.st_size / 1024 + file_infos.append({ + "name": os.path.basename(f), + "path": os.path.abspath(f), + "datetime": dt, + "size_kb": f"{size_kb:.1f} KB" + }) + # HTML table with columns: Name, Date/Time, Size + html_links = '
+
+# 1. Utility function to list all CSV files in the workspace
+import glob
+
+def list_all_csv_files():
+    csv_files = sorted(glob.glob("*.csv"), key=os.path.getmtime, reverse=True)
+    zip_files = sorted(glob.glob("*.zip"), key=os.path.getmtime, reverse=True)
+    all_files = csv_files + zip_files
+    if not all_files:
+        return "No CSV or ZIP files found.", [], []
+    # Gather file info: name, date/time, size
+    file_infos = []
+    for f in all_files:
+        stat = os.stat(f)
+        dt = datetime.datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S')
+        size_kb = stat.st_size / 1024
+        file_infos.append({
+            "name": os.path.basename(f),
+            "path": os.path.abspath(f),
+            "datetime": dt,
+            "size_kb": f"{size_kb:.1f} KB"
+        })
+    # HTML table with columns: Name, Date/Time, Size
+    html_links = '<table><tr><th>File</th><th>Date/Time</th><th>Size</th></tr>'
+    for info in file_infos:
+        html_links += f'<tr><td>{info["name"]}</td>' \
+            f'<td>{info["datetime"]}</td><td>{info["size_kb"]}</td></tr>'