File size: 2,730 Bytes
6dcf428
 
742e3f8
 
 
 
 
 
6dcf428
742e3f8
6dcf428
742e3f8
6dcf428
cec952b
 
 
6dcf428
 
cec952b
6dcf428
742e3f8
 
cec952b
742e3f8
 
 
 
 
6dcf428
3284c4f
 
 
6dcf428
 
 
742e3f8
6dcf428
 
 
742e3f8
1ff0039
6dcf428
9e0554c
575ba93
6dcf428
 
 
 
 
1ff0039
6dcf428
 
 
 
 
cec952b
 
 
 
 
 
6dcf428
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import tempfile

import arxiv
import gradio as gr
from llama_index import (
    VectorStoreIndex,
    ServiceContext,
    SimpleDirectoryReader,
    Document
)
from langchain.llms import HuggingFaceHub
from llama_index.llms import LangChainLLM

# Hugging Face Hub model used as the LLM behind the query engine.
repo_id = 'HuggingFaceH4/zephyr-7b-beta'

# Bounds for the "Max New Tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 10_240     # slider upper limit
DEFAULT_MAX_NEW_TOKENS = 4_096  # slider starting value

def loading_paper():
    """Return the 'Loading...' status message."""
    status = 'Loading...'
    return status

def _load_paper_document(paper_id):
    """Download arXiv paper *paper_id* and return its full text as one Document.

    Raises:
        gr.Error: if arXiv returns no result for *paper_id*.
    """
    search = arxiv.Search(id_list=[paper_id])
    # next() with a default avoids leaking a bare StopIteration out of a
    # Gradio callback when the id matches nothing.
    paper = next(arxiv.Client().results(search), None)
    if paper is None:
        raise gr.Error(f'No arXiv paper found for id {paper_id!r}.')
    # Download into a throwaway directory so repeated loads do not pile PDFs
    # up in the working directory. load_data() returns the parsed documents,
    # so the file is no longer needed once it returns.
    with tempfile.TemporaryDirectory() as tmp_dir:
        pdf_path = paper.download_pdf(dirpath=tmp_dir)
        pages = SimpleDirectoryReader(input_files=[pdf_path]).load_data()
    return Document(text='\n\n'.join(page.text for page in pages))


def paper_changes(paper_id, temperature=0.2, max_tokens=4096, top_p=0.9):
    """Download and index an arXiv paper, publishing a global query engine.

    The paper's text is merged into a single llama_index ``Document``,
    embedded with the local ``BAAI/bge-small-en-v1.5`` model, and the
    resulting query engine is stored in the module-level ``query_engine``
    that ``infer`` uses.

    Args:
        paper_id: arXiv identifier such as ``'1706.03762'``; surrounding
            whitespace is ignored.
        temperature: Sampling temperature forwarded to the Hub model.
        max_tokens: Generation-length limit, forwarded to the Hub as
            ``max_new_tokens``.
        top_p: Nucleus-sampling mass forwarded to the Hub model.

    Returns:
        The status string ``'Ready!!!'`` once the index is built.

    Raises:
        gr.Error: if *paper_id* is blank or matches no arXiv paper.
    """
    global query_engine
    paper_id = (paper_id or '').strip()
    if not paper_id:
        raise gr.Error('Please enter an arXiv paper id.')
    doc = _load_paper_document(paper_id)
    llm = LangChainLLM(llm=HuggingFaceHub(
        repo_id=repo_id,
        # The Hub text-generation API calls this limit 'max_new_tokens';
        # 'max_tokens' is not one of its parameters, so the slider value was
        # not being applied. NOTE(review): the value must also fit within the
        # endpoint's total-token limit.
        model_kwargs={
            'temperature': temperature,
            'max_new_tokens': int(max_tokens),
            'top_p': top_p,
        },
    ))
    service_context = ServiceContext.from_defaults(llm=llm, embed_model="local:BAAI/bge-small-en-v1.5")
    index = VectorStoreIndex.from_documents([doc], service_context=service_context)
    query_engine = index.as_query_engine()
    return 'Ready!!!'

def add_text(history, text):
    """Append *text* as a new, unanswered chat turn.

    Returns a new history list (the input list is not modified) and an empty
    string, which clears the question box.
    """
    new_turn = (text, None)
    return history + [new_turn], ''

def bot(history):
    """Answer the most recent chat turn.

    Args:
        history: Chat history as a list of ``(question, answer)`` pairs whose
            last entry holds the pending question.

    Returns:
        The same list, with its last entry replaced by ``[question, answer]``.
        An empty history is returned unchanged.
    """
    if not history:
        return history
    question = history[-1][0]
    # add_text appends (text, None) tuples, which do not support item
    # assignment; rebuilding the pair works whether the entry is that tuple
    # or a list.
    history[-1] = [question, infer(question)]
    return history

def infer(question):
    """Run *question* against the currently loaded paper's query engine.

    Returns:
        The engine's response converted to ``str``. If no paper has been
        loaded yet, returns a message asking the user to load one.
    """
    # query_engine exists only after paper_changes assigns it as a module
    # global; look it up defensively so asking first gives a readable reply
    # instead of a NameError.
    engine = globals().get('query_engine')
    if engine is None:
        return 'Please load a paper first.'
    return str(engine.query(question))

# Build the UI: a chat panel, paper loader, question box and sampling options.
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    with gr.Column(variant='panel', scale=2):
        chatbot = gr.Chatbot([], elem_id='chatbot')

        with gr.Row():
            paper_id = gr.Textbox(label='ArXiv Paper Id', placeholder='1706.03762')
            langchain_status = gr.Textbox(label='Status', placeholder='', interactive=False)
            load_paper = gr.Button('Load Paper to LLaMa-Index')

        with gr.Row():
            question = gr.Textbox(label='Question', placeholder='Type your query...')
            submit_btn = gr.Button('Submit')

        with gr.Accordion(label='Advanced options', open=False):
            max_new_tokens = gr.Slider(label='Max New Tokens', minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
            temperature = gr.Slider(label='Temperature', minimum=0.1, maximum=4.0, step=0.1, value=0.1)
            # Text-generation-inference endpoints reject top_p >= 1.0, so the
            # slider stops one step short of it.
            top_p = gr.Slider(label='Top-P (nucleus sampling)', minimum=0.05, maximum=0.95, step=0.05, value=0.9)

    # Show 'Loading...' immediately, then download and index the paper.
    load_paper.click(loading_paper, outputs=[langchain_status], queue=False).then(
        paper_changes,
        inputs=[paper_id, temperature, max_new_tokens, top_p],
        outputs=[langchain_status],
        queue=False,
    )
    # Enter in the question box and the Submit button share the same flow:
    # record the question, then answer it.
    question.submit(add_text, [chatbot, question], [chatbot, question]).then(bot, chatbot, chatbot)
    submit_btn.click(add_text, [chatbot, question], [chatbot, question]).then(bot, chatbot, chatbot)

if __name__ == '__main__':
    demo.launch(share=True)