Spaces:
Running
Running
| import gradio as gr | |
| import google.generativeai as genai | |
| from groq import Groq | |
| from ragatouille import RAGPretrainedModel | |
| import arxiv | |
| import os | |
| import re | |
| from datetime import datetime | |
| from utils import get_md_text_abstract | |
| from huggingface_hub import snapshot_download | |
# --- Core Configuration ---
# Secrets and data locations are read from the environment once, at import time.
hf_token = os.environ.get("HF_TOKEN")
gemini_api_key = os.environ.get("GEMINI_API_KEY")
groq_api_key = os.environ.get("GROQ_API_KEY")

# Hugging Face dataset repo that holds the prebuilt index, and where it is
# stored locally.
RAG_SOURCE = os.environ.get("RAG_SOURCE")
LOCAL_DATA_DIR = os.environ.get("LOCAL_DATA_DIR", "./rag_index_data")

# Dropdown entries use the form "<provider>:<model>"; the literal "None"
# entry disables the LLM answer.
LLM_MODELS_TO_CHOOSE = [
    "groq:llama-3.1-8b-instant",
    "groq:llama-3.3-70b-versatile",
    "gemini:gemma-4-26b-a4b-it",
    "gemini:gemma-4-31b-it",
    "None",
]
DEFAULT_LLM_MODEL = "groq:llama-3.1-8b-instant"

# Number of documents requested from either search backend.
RETRIEVE_RESULTS = 20
# --- LLM Provider Configuration (Gemini and Groq) ---
# Each provider is optional: a missing key only produces a console warning and
# the corresponding ask_* function reports the problem at request time.
if not gemini_api_key:
    print("WARNING: GEMINI_API_KEY environment variable not set. Gemini models will not be available.")
else:
    genai.configure(api_key=gemini_api_key)

# groq_client stays None when no key is configured.
groq_client = Groq(api_key=groq_api_key) if groq_api_key else None
if not groq_api_key:
    print("WARNING: GROQ_API_KEY environment variable not set. Groq models will not be available.")

# Sampling settings for Gemini; ask_groq_llm passes the same temperature,
# top_p and 450-token limit to Groq.
GEMINI_GENERATION_CONFIG = genai.types.GenerationConfig(
    temperature=0.2,
    top_p=0.8,
    max_output_tokens=450,
)
# --- RAG & Data Source Setup ---
# Any failure here leaves RAG = None; the rest of the app checks for that.
try:
    gr.Info("Setting up the RAG retriever...")
    index_path = os.path.join(LOCAL_DATA_DIR, "arxiv_colbert")

    if os.path.exists(LOCAL_DATA_DIR):
        gr.Info(f"Found existing local index at {LOCAL_DATA_DIR}.")
    else:
        # No local copy yet: fetch the index dataset from Hugging Face.
        if not (RAG_SOURCE and hf_token):
            raise ValueError("RAG index not found locally, and RAG_SOURCE or HF_TOKEN environment variables are not set. Cannot download index.")
        snapshot_download(
            repo_id=RAG_SOURCE,
            repo_type="dataset",
            token=hf_token,
            local_dir=LOCAL_DATA_DIR,
        )
        gr.Info("Index downloaded successfully.")

    # Load the retriever from the (now existing) local index path.
    gr.Info(f"Loading index from {index_path}...")
    RAG = RAGPretrainedModel.from_index(index_path)
    _ = RAG.search("Test query", k=1)  # Warm-up query
    gr.Info("Retriever loaded successfully!")
except Exception as e:
    gr.Warning(f"Could not initialize the RAG retriever. The app may not function correctly. Error: {e}")
    RAG = None
# --- UI Text and Metadata ---
MARKDOWN_SEARCH_RESULTS_HEADER = '# 🔍 Search Results\n'
APP_HEADER_TEXT = "# ArXiv CS RAG\n"
INDEX_INFO = "Semantic Search"

# Best effort: if README.md carries an "Index Last Updated : YYYY-MM-DD" line,
# show that date in the header and in the semantic-search dropdown label.
try:
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # default locale encoding (a README containing non-ASCII characters would
    # otherwise fail to decode on some systems and the date would be dropped).
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
    date_match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
    if date_match:
        date = date_match.group(1)
        formatted_date = datetime.strptime(date, '%Y-%m-%d').strftime('%d %b %Y')
        APP_HEADER_TEXT += f'Index Last Updated: {formatted_date}\n'
        INDEX_INFO = f"Semantic Search - up to {formatted_date}"
except Exception:
    # Missing file, undecodable content, or an impossible date all fall back
    # to the default labels defined above.
    print("README.md not found or is invalid. Using default data source info.")

# The first entry selects the local semantic index, the second the live arXiv API.
DATABASE_CHOICES = [INDEX_INFO, 'Arxiv Search - Latest - (EXPERIMENTAL)']

# One arXiv API client, created at import time and reused by update_with_rag_md.
ARX_CLIENT = arxiv.Client()
| # --- Helper Functions --- | |
def get_prompt_messages(question, context):
    """Build the (system_instruction, user_message) pair for one query.

    Args:
        question: The user's search question.
        context: Text of the source abstracts, already formatted for the prompt.

    Returns:
        A 2-tuple of strings: the fixed system instruction and a user message
        that embeds ``context`` followed by ``question``.
    """
    instruction_lines = [
        "You are writing the final answer shown to an end user in a scientific paper search app.",
        "Write one plain paragraph of 5-7 concise sentences.",
        "Use only the supplied source abstracts.",
        "Cite paper titles inline in parentheses.",
        "Do not use bullets, numbered lists, headings, labels, markdown, or analysis notes.",
        "Do not restate the question, instructions, constraints, or source list.",
    ]
    system_instruction = "\n".join(instruction_lines)
    user_message = "Source abstracts:\n{}\n\nUser question: {}".format(context, question)
    return system_instruction, user_message
def get_prompt_text(question, context):
    """Return the whole prompt as a single string.

    The layout is: system instruction, blank line, user message, blank line,
    and a trailing ``Final answer:`` cue.
    """
    sections = [*get_prompt_messages(question, context), "Final answer:"]
    return "\n\n".join(sections)
def update_with_rag_md(message, llm_results_use, database_choice):
    """Fetch documents, render them as Markdown, and build the LLM prompt.

    Args:
        message: The user's search query.
        llm_results_use: How many of the top results are included in the
            prompt context (all results are still shown in the Markdown).
        database_choice: The selected search source. ``INDEX_INFO`` selects the
            local semantic index; anything else uses the live arXiv API.

    Returns:
        A 2-tuple ``(markdown_text, final_prompt)``.

    Search failures never propagate: they are reported with ``gr.Warning`` and
    the function continues with whatever results it has (possibly none).
    """
    rag_out = []
    source_used = database_choice
    try:
        if database_choice == INDEX_INFO and RAG:
            rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
        else:
            search = arxiv.Search(
                query=message,
                max_results=RETRIEVE_RESULTS,
                sort_by=arxiv.SortCriterion.Relevance,
            )
            rag_out = list(ARX_CLIENT.results(search))
            source_used = "Arxiv Search"
            if not rag_out:
                gr.Warning("Live Arxiv search returned no results. Falling back to semantic search.")
                if RAG:
                    rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
                    source_used = INDEX_INFO
    except Exception as e:
        gr.Warning(f"An error occurred during search: {e}. Falling back to semantic search.")
        if RAG:
            # The fallback can fail too (for example when the semantic search
            # itself raised the first error). Guard it so a second failure
            # yields an empty result list rather than an unhandled exception.
            try:
                rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
                source_used = INDEX_INFO
            except Exception as fallback_error:
                rag_out = []
                gr.Warning(f"Semantic search fallback also failed: {fallback_error}")

    md_parts = [MARKDOWN_SEARCH_RESULTS_HEADER]
    context_parts = []
    for i, rag_answer in enumerate(rag_out):
        md_text_paper, prompt_text = get_md_text_abstract(
            rag_answer, source=source_used, return_prompt_formatting=True
        )
        # Only the top `llm_results_use` results are given to the LLM.
        if i < llm_results_use:
            context_parts.append(f"{i+1}. {prompt_text}\n")
        md_parts.append(md_text_paper)

    final_prompt = get_prompt_text(message, "".join(context_parts))
    return "".join(md_parts), final_prompt
def parse_llm_choice(llm_model_picked):
    """Split a ``"provider:model"`` dropdown value into its two parts.

    Args:
        llm_model_picked: A string such as ``"groq:llama-3.1-8b-instant"``.
            A value without a colon is treated as a Gemini model name.

    Returns:
        A ``(provider, model_name)`` tuple. Only the first colon separates the
        two parts, so a model name may itself contain colons.
    """
    provider, separator, model_name = llm_model_picked.partition(":")
    if not separator:
        return "gemini", llm_model_picked
    # Always return a tuple so both branches have the same return type.
    return provider, model_name
def split_prompt_for_chat(prompt):
    """Split a single-string prompt back into system and user parts.

    The prompt is cut at the first blank line followed by ``Source abstracts:``.
    A trailing ``Final answer:`` cue is removed from the user part.

    Returns:
        ``(system_instruction, user_message)``. When the marker is absent the
        system part is ``""`` and the user part is the whole prompt minus the cue.
    """
    answer_cue = "\n\nFinal answer:"
    abstracts_header = "Source abstracts:\n"

    head, found, tail = prompt.partition("\n\n" + abstracts_header)
    if not found:
        return "", prompt.removesuffix(answer_cue)
    return head, abstracts_header + tail.removesuffix(answer_cue)
def ask_groq_llm(prompt, model_name, stream_outputs):
    """Generator that queries a Groq chat model and yields the answer text.

    Args:
        prompt: The single-string prompt; it is split into system and user
            messages before being sent.
        model_name: Groq model identifier.
        stream_outputs: When true, yields the accumulated text after every
            streamed chunk; otherwise yields the complete answer once.

    Yields:
        Answer text, or a fixed message when the client is not configured or
        the model returned nothing.
    """
    if not groq_client:
        yield "Error: GROQ_API_KEY is not configured. Cannot contact Groq."
        return

    system_instruction, user_message = split_prompt_for_chat(prompt)
    chat_messages = [
        {"role": "system", "content": system_instruction},
        {"role": "user", "content": user_message},
    ]
    completion = groq_client.chat.completions.create(
        model=model_name,
        messages=chat_messages,
        temperature=0.2,
        top_p=0.8,
        max_completion_tokens=450,
        stream=stream_outputs,
    )

    if not stream_outputs:
        yield completion.choices[0].message.content or "Model returned an empty response."
        return

    # Streaming: re-yield the growing text so the UI always shows the full answer so far.
    accumulated = ""
    for piece in completion:
        accumulated += piece.choices[0].delta.content or ""
        yield accumulated
    if not accumulated:
        yield "Model returned an empty response."
def ask_gemini_llm(prompt, model_name, stream_outputs):
    """Generator that queries a Gemini-API model and yields the answer text.

    Args:
        prompt: The complete prompt string, sent as-is.
        model_name: Model identifier passed to ``genai.GenerativeModel``.
        stream_outputs: When true, yields the accumulated text after every
            streamed chunk that carries text; otherwise yields the answer once.

    Yields:
        Answer text, or a fixed message when the API key is missing or the
        response is empty or blocked.
    """
    if not gemini_api_key:
        yield "Error: GEMINI_API_KEY is not configured. Cannot contact Gemini."
        return

    harm_categories = (
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    safety_settings = [
        {"category": category, "threshold": "BLOCK_ONLY_HIGH"}
        for category in harm_categories
    ]

    model = genai.GenerativeModel(model_name)
    response = model.generate_content(
        prompt,
        generation_config=GEMINI_GENERATION_CONFIG,
        stream=stream_outputs,
        safety_settings=safety_settings,
    )

    # The `.parts` accessor raises ValueError when the response has no
    # candidate (for example a blocked prompt), so it is handled alongside the
    # empty-parts cases and reported with the friendly message below.
    unreadable = (IndexError, AttributeError, ValueError)

    if stream_outputs:
        output = ""
        for chunk in response:
            try:
                text = chunk.parts[0].text
            except unreadable:
                # Ignore empty chunks, which can occur at the end of a stream.
                continue
            output += text
            yield output
        if not output:
            yield "Model returned an empty or blocked stream. This may be due to the safety settings or the nature of the prompt."
    else:
        try:
            answer = response.parts[0].text
        except unreadable:
            answer = "Model returned an empty or blocked response."
        yield answer
def ask_llm(prompt, llm_model_picked, stream_outputs):
    """Route a prompt to the chosen provider and relay its output.

    Args:
        prompt: The prompt string produced by the retrieval step.
        llm_model_picked: Dropdown value, ``"provider:model"`` or ``'None'``.
        stream_outputs: Passed through to the provider function.

    Yields:
        The provider's output, or a single explanatory message when the prompt
        is blank, the LLM is disabled, the provider is unknown, or the provider
        call raised an exception.
    """
    if not (prompt and prompt.strip()):
        yield "Error: The generated prompt is empty. Please try a different query."
        return
    if llm_model_picked == 'None':
        yield "LLM Model is disabled."
        return

    provider, model_name = parse_llm_choice(llm_model_picked)
    try:
        backends = {"groq": ask_groq_llm, "gemini": ask_gemini_llm}
        backend = backends.get(provider)
        if backend is None:
            yield f"Error: Unsupported LLM provider '{provider}'."
        else:
            yield from backend(prompt, model_name, stream_outputs)
    except Exception as e:
        error_message = f"An error occurred with the {provider} API: {e}"
        print(error_message)  # Server side log
        gr.Warning(f"An error occurred with the {provider} API. Check the server logs for details.")
        yield error_message
# --- Gradio User Interface ---
# NOTE(review): the source lost its indentation; the nesting below (settings
# accordion inside the search group, outputs directly under Blocks) is the
# conventional layout and should be confirmed against the running app.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(APP_HEADER_TEXT)

    with gr.Group():
        msg = gr.Textbox(label='Search', placeholder='e.g., What is Mixtral?')
        with gr.Accordion("Advanced Settings", open=False):
            llm_model = gr.Dropdown(
                choices=LLM_MODELS_TO_CHOOSE,
                value=DEFAULT_LLM_MODEL,
                label='LLM Model',
            )
            llm_results = gr.Slider(
                minimum=3,
                maximum=20,
                value=5,
                step=1,
                label="Top n results as context",
            )
            database_src = gr.Dropdown(
                choices=DATABASE_CHOICES,
                value=INDEX_INFO,
                label='Search Source',
            )
            stream_results = gr.Checkbox(value=True, label="Stream output")

    output_text = gr.Textbox(
        label='LLM Answer',
        placeholder="The model's answer will appear here...",
        interactive=False,
        lines=8,
    )
    # Hidden field that carries the generated prompt from retrieval to the LLM step.
    input_prompt = gr.Textbox(visible=False)
    gr_md = gr.Markdown(MARKDOWN_SEARCH_RESULTS_HEADER)

    # On submit: first retrieve and render results, then ask the LLM.
    retrieval_event = msg.submit(
        fn=update_with_rag_md,
        inputs=[msg, llm_results, database_src],
        outputs=[gr_md, input_prompt],
    )
    retrieval_event.then(
        fn=ask_llm,
        inputs=[input_prompt, llm_model, stream_results],
        outputs=[output_text],
    )
if __name__ == "__main__":
    # Enable the request queue, then start the web server.
    queued_demo = demo.queue()
    queued_demo.launch()