File size: 2,670 Bytes
5e09e0a
5e80009
 
d522487
5e80009
067a3d4
 
5e80009
 
fa91957
5c64c3f
5e80009
 
 
 
5e09e0a
 
 
 
 
 
b93e0f2
5e09e0a
5e80009
 
 
 
 
c15373d
5e09e0a
067a3d4
 
5e80009
 
 
6a2dc29
5e80009
fa91957
5e80009
 
067a3d4
5e80009
 
fa91957
 
ee7a93d
938d0d4
25069ce
 
 
ee7a93d
5e09e0a
 
 
 
 
 
 
ade9294
5e09e0a
 
 
 
 
 
 
fa91957
1c17feb
1bfe6b8
5e09e0a
 
d5283e7
b93e0f2
938d0d4
 
5e09e0a
6690b2c
5e09e0a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
from langchain.llms import Replicate
from langchain.vectorstores import Pinecone
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import PyPDFLoader
from langchain.llms import HuggingFaceHub
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from datasets import load_dataset
import os


# The Replicate client reads its credential from REPLICATE_API_TOKEN; this
# script takes it from the 'API' environment variable.
key = os.environ.get('API')
if key:
    os.environ["REPLICATE_API_TOKEN"] = key
# When 'API' is unset, key is None and os.environ rejects non-string values
# with a TypeError, so the assignment is skipped and any REPLICATE_API_TOKEN
# already present in the environment is left untouched.

import sentence_transformers

def loading_pdf():
    """Return the fixed status text ``"Loading..."``."""
    status_message = "Loading..."
    return status_message

def pdf_changes(pdf_doc):
    """Index an uploaded PDF and build the global conversational QA chain.

    The PDF is loaded with ``PyPDFLoader``, split with a
    ``CharacterTextSplitter`` (``chunk_size=1000``, no overlap), embedded with
    ``HuggingFaceEmbeddings`` into a ``Chroma`` store, and wrapped in a
    ``ConversationalRetrievalChain`` backed by a Replicate-hosted model.
    The chain is stored in the module-level ``qa`` variable.

    Args:
        pdf_doc: The uploaded file. Either an object exposing the file path
            as ``.name`` or a plain path string; ``None`` when nothing has
            been uploaded.

    Returns:
        A short status string: ``"Ready"`` on success, otherwise a message
        explaining why no chain was built.
    """
    global qa

    # Clicking the button without a file would otherwise fail on `.name`.
    if pdf_doc is None:
        return "Please upload a PDF first"

    # Accept both a file-like object (path in `.name`) and a bare path string.
    pdf_path = getattr(pdf_doc, "name", pdf_doc)

    loader = PyPDFLoader(pdf_path)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)

    # Nothing to index (e.g. the loader produced no documents): report it
    # instead of building a vector store from an empty list.
    if not texts:
        return "No text could be extracted from this PDF"

    embeddings = HuggingFaceEmbeddings()

    db = Chroma.from_documents(texts, embeddings)
    # Retrieve the 2 most similar chunks for each question.
    retriever = db.as_retriever(search_kwargs={'k': 2})

    llm = Replicate(
        model="a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5",
        input={"temperature": 0.2, "max_length": 3000, "length_penalty":0.1, "num_beams":3}
    )
    qa = ConversationalRetrievalChain.from_llm(
        llm,
        retriever,
        return_source_documents=True
    )
    return "Ready"


def query(history, text):
    """Answer ``text`` with the loaded QA chain and extend the chat history.

    Args:
        history: Chat history as a sequence of ``(user_message, bot_message)``
            pairs, the same shape this function returns.
        text: The question to ask.

    Returns:
        A ``(new_history, "")`` tuple: ``history`` with the new
        ``(text, answer)`` pair appended, and an empty string (the UI wires
        this second output to the question textbox).
    """
    history = list(history or [])

    # `qa` only exists after pdf_changes() has run; look it up defensively so
    # a question asked before any PDF is loaded gets a reply, not a NameError.
    chain = globals().get("qa")
    if chain is None:
        return history + [(text, "Please load a PDF first.")], ""

    # Each history entry is already one (human, ai) exchange, which is the
    # pair format passed as `chat_history`.
    langchain_history = [(user_msg, bot_msg) for user_msg, bot_msg in history]
    result = chain({"question": text, "chat_history": langchain_history})
    new_history = history + [(text, result['answer'])]
    return new_history, ""

# Page-level CSS: caps the width of the element with id "col-container" at
# 700px and centres it horizontally via automatic side margins.
css="""
#col-container {max-width: 700px; margin-left: auto; margin-right: auto;}
"""

# HTML header rendered at the top of the page. The wrapping <div> is closed
# explicitly so the fragment is well-formed on its own.
title = """
<div style="text-align: center;max-width: 700px;">
    <h1>Chat with PDF</h1>   
</div>
"""


# Assemble the page: header, upload controls, chat window and question box,
# all inside the column styled through the "#col-container" CSS rule.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.HTML(value=title)

        # Upload controls and the status field they report into.
        with gr.Column():
            pdf_doc = gr.File(
                label="Load a PDF",
                file_types=['.pdf'],
                type="file",
            )
            load_pdf = gr.Button(value="Load PDF")
            langchain_status = gr.Textbox(
                label="Status",
                placeholder="",
                interactive=False,
            )

        # Conversation view plus the two ways of sending a question.
        chatbot = gr.Chatbot(value=[], elem_id="chatbot").style(height=350)
        question = gr.Textbox(
            label="Question",
            placeholder="Type your question and hit Enter ",
        )
        submit_btn = gr.Button(value="Send message")

    # Event wiring: loading a PDF updates the status box; submitting the
    # textbox or pressing the button both run query(), whose two outputs
    # replace the chat history and the question text.
    load_pdf.click(
        fn=pdf_changes,
        inputs=[pdf_doc],
        outputs=[langchain_status],
        queue=False,
    )
    question.submit(
        fn=query,
        inputs=[chatbot, question],
        outputs=[chatbot, question],
    )
    submit_btn.click(
        fn=query,
        inputs=[chatbot, question],
        outputs=[chatbot, question],
    )

demo.launch()