# Spaces: Runtime error
import json
import re
from datetime import datetime

import arxiv
import gradio as gr
import torch
import transformers
from huggingface_hub import InferenceClient
from ragatouille import RAGPretrainedModel

from helper import (
    SaveResponseAndRead,
    get_arxiv_live_search,
    get_md_text_abstract,
    get_prompt_text,
    get_rag,
    get_references,
    rag_cleaner,
    search_cleaner,
)
# Constants

# Number of results requested from either search backend per query.
RETRIEVE_RESULTS = 20

# Model pre-selected in the "LLM Model" dropdown.
DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'

# Choices offered in the dropdown; the literal string 'None' disables the LLM.
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    DEFAULT_LLM_MODEL,
    'google/gemma-7b-it',
    'None',
]

# Keyword arguments forwarded unchanged to the text-generation call.
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)
# RAG Model setup
# Load the ColBERT index; a missing index disables semantic search entirely.
try:
    RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
except FileNotFoundError:
    RAG = None
    semantic_search_available = False
    gr.Warning("Colbert index not found. Semantic search will be unavailable.")
else:
    semantic_search_available = True
    # Run one throwaway query against the freshly loaded index.
    # NOTE: a failure here is only reported; semantic_search_available
    # stays True, exactly as before.
    try:
        gr.Info("Setting up retriever, please wait...")
        rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
        gr.Info("Retriever working successfully!")
    except Exception as exc:
        gr.Warning(f"Retriever not working: {str(exc)}")
# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"


def format_index_date(readme_text):
    """Return the index date found in ``readme_text`` as ``'DD Mon YYYY'``.

    Looks for a line fragment of the form
    ``Index Last Updated : YYYY-MM-DD``.

    Args:
        readme_text: Full text of the README (may be empty).

    Returns:
        The reformatted date string, or ``None`` when the marker is absent
        or the captured digits do not form a valid calendar date.
    """
    found = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', readme_text)
    if found is None:
        # Previously this case crashed with AttributeError on ``None.group()``.
        return None
    try:
        parsed = datetime.strptime(found.group(1), '%Y-%m-%d')
    except ValueError:
        # Digits matched the pattern but are not a real date (e.g. month 13).
        return None
    return parsed.strftime('%d %b %Y')


# Fallback label, used when the README is missing or carries no usable date.
index_info = "Semantic Search"
try:
    # Explicit encoding so the read does not depend on the host locale.
    with open("README.md", "r", encoding="utf-8") as f:
        mdfile = f.read()
except FileNotFoundError:
    mdfile = ""

formatted_date = format_index_date(mdfile)
if formatted_date is not None:
    header_text += f'Index Last Updated: {formatted_date}\n'
    index_info = f"Semantic Search - up to {formatted_date}"
# Search sources offered in the UI: live arXiv search is always listed, and
# the semantic index is placed ahead of it when the index was loaded.
database_choices = ['Arxiv Search - Latest']
if semantic_search_available:
    database_choices.insert(0, index_info)
# Arxiv API setup
arx_client = arxiv.Client()

# Probe the live search once at startup. The same call is guarded with
# ``except Exception`` in the query handler, so guard it here as well: an
# error during the probe should disable the source, not abort the app.
try:
    check_arxiv_result = get_arxiv_live_search(
        "What is Self Rewarding AI and how can it be used in Multi-Agent Systems?",
        arx_client,
        RETRIEVE_RESULTS,
    )
except Exception as e:
    print(f"Arxiv search probe raised an error: {e}")
    check_arxiv_result = []

# An empty result is treated as "arXiv search unavailable".
is_arxiv_available = len(check_arxiv_result) > 0
if not is_arxiv_available:
    print("Arxiv search not working, switching to default search ...")
    # Only the semantic index remains as a selectable source.
    database_choices = [index_info]
# Gradio UI setup
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)

    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                # Default to the first offered source. Using index_info
                # unconditionally selected a value that is absent from the
                # choices whenever semantic search is unavailable.
                database_src = gr.Dropdown(choices=database_choices, value=database_choices[0], label='Search Source')
                # Hidden toggle: streaming is always on from the user's point of view.
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)

    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden textbox that carries the built prompt from the search step to the
    # LLM step. The name shadows the builtin ``input``; it is kept because the
    # event wiring elsewhere in this file refers to it.
    input = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)
def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
    """Run the search and build both the results markdown and the LLM prompt.

    Args:
        search_query: Text the user searched for.
        llm_results_use: How many of the leading results are also formatted
            into the prompt context.
        database_choice: Selected search source; the semantic index is used
            only when this equals ``index_info`` and the index is available,
            otherwise the live arXiv search is used.
        llm_model_picked: Model name forwarded to ``get_prompt_text``.

    Returns:
        ``(markdown, prompt)``. Both are empty strings when the arXiv search
        raises or returns nothing (a warning is shown in that case).
    """
    use_semantic_index = database_choice == index_info and semantic_search_available

    if use_semantic_index:
        papers = get_rag(search_query, RAG, RETRIEVE_RESULTS)
        source_name = 'Semantic Search'
    else:
        try:
            papers = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
            search_ok = len(papers) != 0
        except Exception as exc:
            search_ok = False
            gr.Warning(f"Arxiv Search not working: {str(exc)}")

        if not search_ok:
            gr.Warning("Arxiv search failed. Please try again later.")
            return "", ""
        source_name = 'Arxiv Search'

    # Every result is rendered to markdown; only the first ``llm_results_use``
    # additionally contribute a numbered entry to the prompt context.
    md_sections = [mark_text]
    prompt_entries = []
    for idx, paper in enumerate(papers):
        if idx < llm_results_use:
            paper_md, paper_prompt = get_md_text_abstract(paper, source=source_name, return_prompt_formatting=True)
            prompt_entries.append(f"{idx+1}. {paper_prompt}")
        else:
            paper_md = get_md_text_abstract(paper, source=source_name)
        md_sections.append(paper_md)

    prompt = get_prompt_text(search_query, "".join(prompt_entries), llm_model_picked=llm_model_picked)
    return "".join(md_sections), prompt
def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
    """Yield the LLM answer for ``prompt`` as text for the answer box.

    This is a generator function, so the caller only ever receives yielded
    values; anything passed to ``return`` is discarded. Every result,
    streamed or not, is therefore delivered with ``yield``.

    Args:
        prompt: Fully formatted prompt text. An empty prompt (produced when
            the search step failed) yields an empty answer without contacting
            the inference service.
        llm_model_picked: Model identifier, or the string ``'None'`` to
            disable the LLM.
        stream_outputs: When true, yield the growing answer as it arrives;
            otherwise yield the complete answer once.

    Yields:
        The text to display. On inference failure a warning is raised in the
        UI and an empty string is yielded.
    """
    if llm_model_picked == 'None':
        model_disabled_text = "LLM Model is disabled"
        if stream_outputs:
            # Reveal the notice character by character, like a streamed answer.
            output = ""
            for out in model_disabled_text:
                output += out
                yield output
        else:
            yield model_disabled_text
        # Stop here: without this the code went on to query a model named 'None'.
        return

    if not prompt:
        # Nothing to ask; avoid a pointless remote call.
        yield ""
        return

    client = InferenceClient(llm_model_picked)
    try:
        response = client.text_generation(
            prompt,
            stream=stream_outputs,
            details=False,
            return_full_text=False,
            **GENERATE_KWARGS,
        )
        if stream_outputs:
            output = ""
            for token in response:
                output += token
                # The helper's return value is what gets displayed; its exact
                # behaviour is defined in helper.py.
                yield SaveResponseAndRead(output)
        else:
            yield response
    except Exception as e:
        gr.Warning(f"LLM Inference failed: {str(e)}")
        yield ""
# Event listeners must be registered inside the Blocks context, so re-enter it.
# Step 1 runs the search and stores the prompt in the hidden ``input`` box;
# step 2 runs only if step 1 finished without error and streams the answer.
with demo:
    search_event = search_query.submit(
        update_with_rag_md,
        [search_query, llm_results, database_src, llm_model],
        [gr_md, input],
    )
    search_event.success(ask_llm, [input, llm_model, stream_results], output_text)

demo.queue().launch()