# Arxiv-CS-RAG / app.py
# Bishmoy Paul
# fix: python 3.11 parsing fix
# 874e1c1
import gradio as gr
import google.generativeai as genai
from groq import Groq
from ragatouille import RAGPretrainedModel
import arxiv
import os
import re
from datetime import datetime
from utils import get_md_text_abstract
from huggingface_hub import snapshot_download
# --- Core Configuration ---
# Credentials and data-source settings come from the environment (None when unset).
hf_token, gemini_api_key, groq_api_key, RAG_SOURCE = map(
    os.getenv, ("HF_TOKEN", "GEMINI_API_KEY", "GROQ_API_KEY", "RAG_SOURCE")
)
# On-disk location of the downloaded RAG index data.
LOCAL_DATA_DIR = os.getenv("LOCAL_DATA_DIR", "./rag_index_data")

# Dropdown entries use the "<provider>:<model>" form; "None" disables the LLM step.
LLM_MODELS_TO_CHOOSE = [
    # Served through the Groq API.
    "groq:llama-3.1-8b-instant",
    "groq:llama-3.3-70b-versatile",
    # Served through the Gemini API.
    "gemini:gemma-4-26b-a4b-it",
    "gemini:gemma-4-31b-it",
    # Skip answer generation entirely.
    "None",
]
DEFAULT_LLM_MODEL = "groq:llama-3.1-8b-instant"
# Number of papers fetched per search, from either source.
RETRIEVE_RESULTS = 20
# --- LLM Provider Configuration ---
# Each provider is set up only when its key is present; otherwise a warning is
# printed and the matching ask_* helper reports the missing key at query time.
if not gemini_api_key:
    print("WARNING: GEMINI_API_KEY environment variable not set. Gemini models will not be available.")
else:
    genai.configure(api_key=gemini_api_key)

groq_client = Groq(api_key=groq_api_key) if groq_api_key else None
if groq_client is None:
    print("WARNING: GROQ_API_KEY environment variable not set. Groq models will not be available.")

# Sampling settings passed to every Gemini generate_content call.
GEMINI_GENERATION_CONFIG = genai.types.GenerationConfig(
    top_p=0.8,
    max_output_tokens=450,
    temperature=0.2,
)
# --- RAG & Data Source Setup ---
# Path of the ColBERT index that RAGPretrainedModel loads below.
RAG_INDEX_PATH = os.path.join(LOCAL_DATA_DIR, "arxiv_colbert")
try:
    gr.Info("Setting up the RAG retriever...")
    # Check for the index itself rather than its parent directory: a data dir
    # left behind without the index (e.g. an interrupted download) would
    # otherwise skip the download and then fail to load.
    if not os.path.exists(RAG_INDEX_PATH):
        if not RAG_SOURCE or not hf_token:
            raise ValueError("RAG index not found locally, and RAG_SOURCE or HF_TOKEN environment variables are not set. Cannot download index.")
        snapshot_download(
            repo_id=RAG_SOURCE,
            repo_type="dataset",
            token=hf_token,
            local_dir=LOCAL_DATA_DIR,
        )
        gr.Info("Index downloaded successfully.")
    else:
        gr.Info(f"Found existing local index at {LOCAL_DATA_DIR}.")
    gr.Info(f"Loading index from {RAG_INDEX_PATH}...")
    RAG = RAGPretrainedModel.from_index(RAG_INDEX_PATH)
    _ = RAG.search("Test query", k=1)  # Warm-up query
    gr.Info("Retriever loaded successfully!")
except Exception as e:
    # Best-effort: the app keeps running without semantic search (RAG is None),
    # and update_with_rag_md then uses the live arXiv source instead.
    gr.Warning(f"Could not initialize the RAG retriever. The app may not function correctly. Error: {e}")
    RAG = None
# --- UI Text and Metadata ---
MARKDOWN_SEARCH_RESULTS_HEADER = '# 🔍 Search Results\n'
APP_HEADER_TEXT = "# ArXiv CS RAG\n"
INDEX_INFO = "Semantic Search"


def get_index_last_updated(readme_path="README.md"):
    """Return the index date recorded in *readme_path*, formatted like ``05 Mar 2024``.

    Looks for a line of the form ``Index Last Updated : YYYY-MM-DD``.

    Returns:
        The formatted date string, or ``None`` when no such line exists. Also
        returns ``None`` (after printing a notice) when the file cannot be read
        or decoded, or the date is not a valid ``YYYY-MM-DD`` value.
    """
    try:
        # Explicit UTF-8: the README may contain non-ASCII text, and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(readme_path, "r", encoding="utf-8") as f:
            mdfile = f.read()
        date_match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
        if not date_match:
            return None
        return datetime.strptime(date_match.group(1), '%Y-%m-%d').strftime('%d %b %Y')
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError covers both
        # UnicodeDecodeError and an impossible date rejected by strptime.
        print("README.md not found or is invalid. Using default data source info.")
        return None


_index_date = get_index_last_updated()
if _index_date:
    APP_HEADER_TEXT += f'Index Last Updated: {_index_date}\n'
    INDEX_INFO = f"Semantic Search - up to {_index_date}"
DATABASE_CHOICES = [INDEX_INFO, 'Arxiv Search - Latest - (EXPERIMENTAL)']
# Shared arxiv API client, reused by update_with_rag_md for the live
# "Arxiv Search" source.
ARX_CLIENT = arxiv.Client()
# --- Helper Functions ---
def get_prompt_messages(question, context):
    """Build the (system, user) message pair for the answer-writing LLM.

    Args:
        question: The user's search query.
        context: The numbered source abstracts to ground the answer in.

    Returns:
        A ``(system_instruction, user_message)`` tuple of strings.
    """
    rules = [
        "You are writing the final answer shown to an end user in a scientific paper search app.",
        "Write one plain paragraph of 5-7 concise sentences.",
        "Use only the supplied source abstracts.",
        "Cite paper titles inline in parentheses.",
        "Do not use bullets, numbered lists, headings, labels, markdown, or analysis notes.",
        "Do not restate the question, instructions, constraints, or source list.",
    ]
    system_text = "\n".join(rules)
    user_text = f"Source abstracts:\n{context}\n\nUser question: {question}"
    return system_text, user_text
def get_prompt_text(question, context):
    """Flatten the chat messages for *question*/*context* into one prompt string.

    Layout: the system instruction, a blank line, the user message, a blank
    line, then the ``Final answer:`` cue. ``split_prompt_for_chat`` relies on
    this shape to recover the two messages.
    """
    message_pair = get_prompt_messages(question, context)
    return "\n\n".join((*message_pair, "Final answer:"))
def update_with_rag_md(message, llm_results_use, database_choice):
    """Search for papers, render them as Markdown, and build the LLM prompt.

    Uses the local semantic index when *database_choice* is ``INDEX_INFO`` and
    the retriever loaded; otherwise queries the live arXiv API, falling back to
    the local index when that returns nothing or raises.

    Args:
        message: The user's search query.
        llm_results_use: How many of the top results go into the LLM context.
        database_choice: The selected entry of ``DATABASE_CHOICES``.

    Returns:
        ``(markdown, prompt)``: the rendered result list (starting with
        ``MARKDOWN_SEARCH_RESULTS_HEADER``) and the prompt from ``get_prompt_text``.
    """
    rag_out = []
    source_used = database_choice
    # Whether the local index was already queried: if that call is what failed,
    # repeating it immediately as a "fallback" would just fail again (uncaught).
    semantic_attempted = False
    try:
        if database_choice == INDEX_INFO and RAG:
            semantic_attempted = True
            rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
        else:
            search = arxiv.Search(query=message, max_results=RETRIEVE_RESULTS, sort_by=arxiv.SortCriterion.Relevance)
            rag_out = list(ARX_CLIENT.results(search))
            source_used = "Arxiv Search"
            if not rag_out:
                gr.Warning("Live Arxiv search returned no results. Falling back to semantic search.")
                if RAG:
                    semantic_attempted = True
                    rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
                    source_used = INDEX_INFO
    except Exception as e:
        rag_out = []
        if RAG and not semantic_attempted:
            gr.Warning(f"An error occurred during search: {e}. Falling back to semantic search.")
            # Guard the fallback too, so a second failure degrades to "no
            # results" instead of erroring out the whole event handler.
            try:
                rag_out = RAG.search(message, k=RETRIEVE_RESULTS)
                source_used = INDEX_INFO
            except Exception as fallback_error:
                gr.Warning(f"Semantic search fallback also failed: {fallback_error}")
                rag_out = []
        else:
            gr.Warning(f"An error occurred during search: {e}")

    md_parts = [MARKDOWN_SEARCH_RESULTS_HEADER]
    context_parts = []
    for i, rag_answer in enumerate(rag_out):
        md_text_paper, prompt_text = get_md_text_abstract(rag_answer, source=source_used, return_prompt_formatting=True)
        # Every result is rendered, but only the top `llm_results_use` feed the LLM.
        if i < llm_results_use:
            context_parts.append(f"{i+1}. {prompt_text}\n")
        md_parts.append(md_text_paper)
    return "".join(md_parts), get_prompt_text(message, "".join(context_parts))
def parse_llm_choice(llm_model_picked):
    """Split a ``"<provider>:<model>"`` dropdown value into ``(provider, model)``.

    Only the first ``:`` separates the provider, so model names may contain
    colons. A value with no ``:`` is treated as a Gemini model name.

    Returns:
        A ``(provider, model_name)`` tuple of strings (always a tuple).
    """
    provider, separator, model_name = llm_model_picked.partition(":")
    if not separator:
        return "gemini", llm_model_picked
    return provider, model_name
def split_prompt_for_chat(prompt):
    """Split a flat prompt (as built by ``get_prompt_text``) into chat messages.

    Returns:
        ``(system_instruction, user_message)`` with the trailing
        ``Final answer:`` cue removed. When the prompt has no
        ``Source abstracts:`` section, the system part is ``""`` and the whole
        prompt becomes the user message.
    """
    marker = "\n\nSource abstracts:\n"
    head, found, tail = prompt.partition(marker)
    if not found:
        return "", prompt.removesuffix("\n\nFinal answer:")
    return head, "Source abstracts:\n" + tail.removesuffix("\n\nFinal answer:")
def ask_groq_llm(prompt, model_name, stream_outputs):
    """Send *prompt* to a Groq chat model and yield the answer text.

    The flat prompt is split back into system/user messages with
    ``split_prompt_for_chat``.

    Args:
        prompt: Flat prompt text built by ``get_prompt_text``.
        model_name: Groq model identifier (provider prefix already removed).
        stream_outputs: When true, yield the accumulated answer after each
            streamed piece; otherwise yield one complete answer.

    Yields:
        Answer text, or an error/empty-response message. API exceptions
        propagate to the caller.
    """
    if not groq_client:
        yield "Error: GROQ_API_KEY is not configured. Cannot contact Groq."
        return
    system_instruction, user_message = split_prompt_for_chat(prompt)
    response = groq_client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": user_message},
        ],
        temperature=0.2,
        top_p=0.8,
        max_completion_tokens=450,
        stream=stream_outputs,
    )
    if not stream_outputs:
        yield response.choices[0].message.content or "Model returned an empty response."
        return
    output = ""
    for chunk in response:
        # Skip chunks that carry no choices (OpenAI-compatible streams may emit
        # metadata-only chunks) instead of failing the answer with IndexError,
        # mirroring how ask_gemini_llm tolerates empty chunks.
        if not chunk.choices:
            continue
        text = chunk.choices[0].delta.content or ""
        if text:
            output += text
            yield output
    if not output:
        yield "Model returned an empty response."
def gemini_response_text(response):
    """Return all text in a Gemini response or stream chunk, or "" if it has none.

    In google-generativeai the ``parts`` accessor raises ``ValueError`` when
    the response has no single candidate (e.g. a blocked prompt). That case,
    a missing ``parts`` attribute, and non-text parts all contribute "" here
    instead of raising.
    """
    try:
        parts = response.parts
    except (ValueError, AttributeError):
        return ""
    # Join every text part; a response is not guaranteed to be a single part.
    return "".join(getattr(part, "text", "") or "" for part in parts)


def ask_gemini_llm(prompt, model_name, stream_outputs):
    """Send *prompt* to a Gemini-API model and yield the answer text.

    The flat prompt is sent as a single content string.

    Args:
        prompt: Flat prompt text built by ``get_prompt_text``.
        model_name: Gemini-API model identifier (provider prefix removed).
        stream_outputs: When true, yield the accumulated answer after each
            chunk that carries text; otherwise yield one complete answer.

    Yields:
        Answer text, or a message explaining an empty/blocked result. Other
        API exceptions propagate to the caller.
    """
    if not gemini_api_key:
        yield "Error: GEMINI_API_KEY is not configured. Cannot contact Gemini."
        return
    safety_settings = [
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
        {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
        {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_ONLY_HIGH"},
        {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_ONLY_HIGH"},
    ]
    model = genai.GenerativeModel(model_name)
    response = model.generate_content(
        prompt,
        generation_config=GEMINI_GENERATION_CONFIG,
        stream=stream_outputs,
        safety_settings=safety_settings,
    )
    if not stream_outputs:
        yield gemini_response_text(response) or "Model returned an empty or blocked response."
        return
    output = ""
    for chunk in response:
        text = gemini_response_text(chunk)
        # Empty chunks (e.g. at the end of a stream, or blocked ones) are skipped.
        if text:
            output += text
            yield output
    if not output:
        yield "Model returned an empty or blocked stream. This may be due to the safety settings or the nature of the prompt."
def ask_llm(prompt, llm_model_picked, stream_outputs):
    """Send *prompt* to the selected LLM provider and yield the response text.

    Args:
        prompt: Flat prompt text produced by ``update_with_rag_md``.
        llm_model_picked: A ``"<provider>:<model>"`` choice, or ``'None'`` to
            skip generation.
        stream_outputs: Whether to yield partial answers as they arrive.

    Yields:
        Answer text (cumulative when streaming) or a user-facing message.
        Provider exceptions are printed to the server log and surfaced as a
        message rather than raised.
    """
    if not prompt or not prompt.strip():
        yield "Error: The generated prompt is empty. Please try a different query."
        return
    # A cleared dropdown may arrive as None or ""; treat that like the explicit
    # 'None' choice instead of crashing in parse_llm_choice.
    if not llm_model_picked or llm_model_picked == 'None':
        yield "LLM Model is disabled."
        return
    provider, model_name = parse_llm_choice(llm_model_picked)
    handlers = {"groq": ask_groq_llm, "gemini": ask_gemini_llm}
    try:
        handler = handlers.get(provider)
        if handler is None:
            yield f"Error: Unsupported LLM provider '{provider}'."
            return
        yield from handler(prompt, model_name, stream_outputs)
    except Exception as e:
        error_message = f"An error occurred with the {provider} API: {e}"
        print(error_message)  # Server side log
        gr.Warning(f"An error occurred with the {provider} API. Check the server logs for details.")
        yield error_message
# --- Gradio User Interface ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(APP_HEADER_TEXT)
    with gr.Group():
        msg = gr.Textbox(label='Search', placeholder='e.g., What is Mixtral?')
        with gr.Accordion("Advanced Settings", open=False):
            llm_model = gr.Dropdown(choices=LLM_MODELS_TO_CHOOSE, value=DEFAULT_LLM_MODEL, label='LLM Model')
            llm_results = gr.Slider(3, 20, value=5, step=1, label="Top n results as context")
            database_src = gr.Dropdown(choices=DATABASE_CHOICES, value=INDEX_INFO, label='Search Source')
            stream_results = gr.Checkbox(value=True, label="Stream output")
    output_text = gr.Textbox(label='LLM Answer', placeholder="The model's answer will appear here...", interactive=False, lines=8)
    # Hidden carrier for the prompt built by update_with_rag_md.
    input_prompt = gr.Textbox(visible=False)
    gr_md = gr.Markdown(MARKDOWN_SEARCH_RESULTS_HEADER)
    # `.success` (not `.then`): only ask the LLM if the search step completed.
    # Otherwise input_prompt still holds the previous query's prompt and the
    # model would answer the wrong question.
    msg.submit(
        fn=update_with_rag_md,
        inputs=[msg, llm_results, database_src],
        outputs=[gr_md, input_prompt],
    ).success(
        fn=ask_llm,
        inputs=[input_prompt, llm_model, stream_results],
        outputs=[output_text],
    )
if __name__ == "__main__":
    # Launch the app: enable Gradio's request queue on `demo`, then start
    # the web server.
    demo.queue().launch()