File size: 6,172 Bytes
3701fee
 
 
 
 
 
 
 
 
3dbe475
 
 
 
 
 
 
 
 
 
 
 
3701fee
 
bae5df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3701fee
3dbe475
3701fee
 
 
 
3dbe475
 
 
 
 
 
 
 
 
 
 
bae5df9
 
 
 
3dbe475
 
3701fee
 
3dbe475
3701fee
3dbe475
 
 
3701fee
3dbe475
 
3701fee
 
 
3dbe475
3701fee
3dbe475
 
 
 
 
 
 
 
 
 
 
 
3701fee
3dbe475
bae5df9
3dbe475
bae5df9
3701fee
3dbe475
 
 
 
 
 
 
bae5df9
3dbe475
 
bae5df9
 
 
 
3dbe475
3701fee
3dbe475
 
 
 
 
 
 
 
 
3701fee
3dbe475
 
 
 
3701fee
3dbe475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3701fee
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import transformers
import gradio as gr
from ragatouille import RAGPretrainedModel
import re
from datetime import datetime
import json
import arxiv

from helper import rag_cleaner, get_prompt_text, get_references, get_rag, SaveResponseAndRead, get_md_text_abstract, search_cleaner, get_arxiv_live_search

# Constants
# How many papers each search backend is asked for (passed to both
# get_rag and get_arxiv_live_search).
RETRIEVE_RESULTS = 20

# Models offered in the "LLM Model" dropdown; the literal string 'None'
# is the entry that switches the LLM answer off.
LLM_MODELS = [
    'mistralai/Mixtral-8x7B-Instruct-v0.1',
    'mistralai/Mistral-7B-Instruct-v0.2',
    'google/gemma-7b-it',
    'None',
]
DEFAULT_LLM_MODEL = 'mistralai/Mistral-7B-Instruct-v0.2'

# Keyword arguments forwarded verbatim to the text-generation call.
# Sampling is disabled, and temperature / top_p are explicitly unset.
GENERATE_KWARGS = dict(
    temperature=None,
    max_new_tokens=512,
    top_p=None,
    do_sample=False,
)

try:
    # RAG Model setup
    # Load the retriever from the on-disk ColBERT index. A FileNotFoundError
    # (presumably raised when the index directory is missing -- confirm against
    # ragatouille) is handled by the outer except branch below.
    RAG = RAGPretrainedModel.from_index("colbert/indexes/arxiv_colbert")
    semantic_search_available = True

    try:
        # Run a single throwaway query at startup to check the retriever responds.
        gr.Info("Setting up retriever, please wait...")
        rag_initial_output = RAG.search("What is Generative AI in Healthcare?", k=1)
        gr.Info("Retriever working successfully!")
    except Exception as e:
        # NOTE(review): a failed warm-up query only emits a warning;
        # semantic_search_available stays True, so the semantic source is still
        # offered in the UI. Confirm this best-effort behavior is intended.
        gr.Warning(f"Retriever not working: {str(e)}")

except FileNotFoundError:
    # No index available: disable semantic search so that only the live
    # Arxiv source is offered further down.
    RAG = None
    semantic_search_available = False
    gr.Warning("Colbert index not found. Semantic search will be unavailable.")

# Header setup
mark_text = '# 🩺🔍 Search Results\n'
header_text = "## Arxiv Paper Summary With QA Retrieval Augmented Generation \n"


def get_index_last_updated(readme_path="README.md"):
    """Return the index date advertised in the README, formatted as 'DD Mon YYYY'.

    The README is expected to contain a line of the form
    ``Index Last Updated : YYYY-MM-DD``.

    Args:
        readme_path: Path of the markdown file to scan.

    Returns:
        The formatted date string, or ``None`` when the file does not exist,
        contains no such line, or the date on that line is not a real
        calendar date.
    """
    try:
        with open(readme_path, "r", encoding="utf-8") as f:
            mdfile = f.read()
    except FileNotFoundError:
        return None

    match = re.search(r'Index Last Updated : (\d{4}-\d{2}-\d{2})', mdfile)
    if match is None:
        # README exists but carries no date line; previously this crashed
        # with AttributeError on ``match.group()``.
        return None

    try:
        parsed = datetime.strptime(match.group(1), '%Y-%m-%d')
    except ValueError:
        # The regex only checks the digit layout, so e.g. month 13 gets here.
        return None
    return parsed.strftime('%d %b %Y')


# Label used for the semantic-search entry of the "Search Source" dropdown;
# it includes the index date whenever one could be read from the README.
formatted_date = get_index_last_updated()
if formatted_date is not None:
    header_text += f'Index Last Updated: {formatted_date}\n'
    index_info = f"Semantic Search - up to {formatted_date}"
else:
    index_info = "Semantic Search"

# The live Arxiv search is always listed; the semantic index is put in front
# of it only when the index was loaded.
database_choices = ['Arxiv Search - Latest']
if semantic_search_available:
    database_choices.insert(0, index_info)

# Arxiv API setup
arx_client = arxiv.Client()

# Probe the live Arxiv search once at startup. The probe is wrapped in
# try/except (as the per-query call in the UI handler is) so that a network
# or API error degrades to "Arxiv unavailable" instead of aborting startup.
try:
    check_arxiv_result = get_arxiv_live_search("What is Self Rewarding AI and how can it be used in Multi-Agent Systems?", arx_client, RETRIEVE_RESULTS)
except Exception as e:
    print(f"Arxiv search check raised an error: {e}")
    check_arxiv_result = []

# An empty result set is treated the same as a failed probe.
is_arxiv_available = len(check_arxiv_result) != 0
if not is_arxiv_available:
    print("Arxiv search not working, switching to default search ...")
    database_choices = [index_info]

# Gradio UI setup
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    header = gr.Markdown(header_text)

    with gr.Group():
        search_query = gr.Textbox(label='Search', placeholder='What is Generative AI in Healthcare?')

        with gr.Accordion("Advanced Settings", open=False):
            with gr.Row(equal_height=True):
                llm_model = gr.Dropdown(choices=LLM_MODELS, value=DEFAULT_LLM_MODEL, label='LLM Model')
                llm_results = gr.Slider(minimum=4, maximum=10, value=5, step=1, interactive=True, label="Top n results as context")
                # Default to the first offered source. The previous default,
                # index_info, is not among the choices when the semantic index
                # is unavailable; database_choices[0] equals index_info in
                # every case where index_info is listed.
                database_src = gr.Dropdown(choices=database_choices, value=database_choices[0], label='Search Source')
                # Hidden toggle: streaming is on by default and not user-configurable.
                stream_results = gr.Checkbox(value=True, label="Stream output", visible=False)

    output_text = gr.Textbox(show_label=True, container=True, label='LLM Answer', visible=True)
    # Hidden textbox that carries the generated prompt from the search step
    # to the LLM step. NOTE(review): the name shadows the builtin `input`;
    # it is kept because the event wiring below refers to it by this name.
    input = gr.Textbox(show_label=False, visible=False)
    gr_md = gr.Markdown(mark_text)
    
    def update_with_rag_md(search_query, llm_results_use=5, database_choice=index_info, llm_model_picked=DEFAULT_LLM_MODEL):
        """Run the search and build both the results markdown and the LLM prompt.

        Args:
            search_query: Text typed into the search box.
            llm_results_use: Number of leading results whose prompt text is
                included as context for the LLM.
            database_choice: Selected search source; the semantic index is used
                only when this equals ``index_info`` and the index is available,
                otherwise the live Arxiv search is used.
            llm_model_picked: Model name forwarded to ``get_prompt_text``.

        Returns:
            A ``(markdown, prompt)`` pair, or ``("", "")`` when the Arxiv
            search raised or returned no results.
        """
        if database_choice == index_info and semantic_search_available:
            database_to_use = 'Semantic Search'
            rag_out = get_rag(search_query, RAG, RETRIEVE_RESULTS)
        else:
            database_to_use = 'Arxiv Search'
            try:
                rag_out = get_arxiv_live_search(search_query, arx_client, RETRIEVE_RESULTS)
                arxiv_ok = len(rag_out) != 0
            except Exception as e:
                arxiv_ok = False
                gr.Warning(f"Arxiv Search not working: {str(e)}")

            if not arxiv_ok:
                gr.Warning("Arxiv search failed. Please try again later.")
                return "", ""

        # Every result is rendered into the markdown; only the first
        # `llm_results_use` also contribute a numbered entry to the prompt.
        markdown_parts = [mark_text]
        context_parts = []
        for position, paper in enumerate(rag_out, start=1):
            if position <= llm_results_use:
                paper_md, paper_prompt = get_md_text_abstract(paper, source=database_to_use, return_prompt_formatting=True)
                context_parts.append(f"{position}. {paper_prompt}")
            else:
                paper_md = get_md_text_abstract(paper, source=database_to_use)
            markdown_parts.append(paper_md)

        results_markdown = "".join(markdown_parts)
        prompt = get_prompt_text(search_query, "".join(context_parts), llm_model_picked=llm_model_picked)
        return results_markdown, prompt
    
    def ask_llm(prompt, llm_model_picked=DEFAULT_LLM_MODEL, stream_outputs=False):
        """Yield the LLM answer for ``prompt`` to the answer textbox.

        This is a generator function, so results only reach the UI when they
        are yielded; a value passed to ``return`` is discarded. Every outcome
        is therefore yielded, including the non-streaming and error cases.

        Args:
            prompt: Full prompt text to send to the model.
            llm_model_picked: Model identifier, or the string 'None' to skip
                inference and show a "disabled" message instead.
            stream_outputs: When true, yield the growing answer after each
                token; otherwise yield the complete answer once.

        Yields:
            The text to display: partial answers while streaming, the full
            answer otherwise, the disabled message for model 'None', or an
            empty string if inference fails.
        """
        model_disabled_text = "LLM Model is disabled"

        if llm_model_picked == 'None':
            if stream_outputs:
                # Reveal the message one character at a time.
                shown = ""
                for out in model_disabled_text:
                    shown += out
                    yield shown
            else:
                yield model_disabled_text
            # Stop here; without this the streaming case fell through and
            # attempted inference with a model literally named 'None'.
            return

        # NOTE(review): InferenceClient is not imported anywhere in the visible
        # part of this file -- presumably huggingface_hub.InferenceClient.
        # Confirm the import exists, otherwise this line raises NameError.
        client = InferenceClient(llm_model_picked)
        output = ""
        try:
            response = client.text_generation(prompt, stream=stream_outputs, details=False, return_full_text=False, **GENERATE_KWARGS)

            if stream_outputs:
                for token in response:
                    output += token
                    yield SaveResponseAndRead(output)
            else:
                yield response
        except Exception as e:
            gr.Warning(f"LLM Inference failed: {str(e)}")
            # Clear the answer box on failure.
            yield ""
    
    # On submit: run the search first (fills the results markdown and the
    # hidden prompt box); only if that step succeeds, send the prompt to the LLM.
    search_query.submit(
        fn=update_with_rag_md,
        inputs=[search_query, llm_results, database_src, llm_model],
        outputs=[gr_md, input],
    ).success(
        fn=ask_llm,
        inputs=[input, llm_model, stream_results],
        outputs=output_text,
    )

# Enable the request queue, then start the app.
demo.queue()
demo.launch()