File size: 1,936 Bytes
91d72cc
0682ee2
 
5c402d8
0682ee2
804c5c7
0682ee2
 
047da88
91d72cc
 
5c402d8
047da88
5c402d8
047da88
 
074f5a4
047da88
 
5c402d8
 
6fcd382
5c402d8
 
804c5c7
047da88
0682ee2
 
804c5c7
047da88
 
 
0682ee2
 
804c5c7
600cadf
 
0682ee2
 
804c5c7
 
5c402d8
804c5c7
5c402d8
0682ee2
 
 
5c402d8
0682ee2
 
 
 
 
 
 
 
 
 
4ade5dc
0682ee2
f8d53e0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
from langchain import HuggingFaceHub
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import PyMuPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from transformers import GPT2TokenizerFast


# Number of search results to query from the vector database.
# Passed as k= to FAISS similarity_search in ask().
SIMILARITY_SEARCH_COUNT = 8

# Size of each document chunk in number of tokens (measured with the
# GPT-2 tokenizer used by the splitter below).
CHUNK_SIZE = 100

# Chunk overlap in number of tokens, so sentences spanning a chunk
# boundary are not lost entirely.
CHUNK_OVERLAP = 10

# Maximum number of output tokens the LLM may generate per answer
# (forwarded as max_length in model_kwargs).
MODEL_MAX_LENGTH = 500


# --- One-time pipeline setup, executed at import time ---------------------
# Stage 1: load the source PDF into langchain Document objects (one per page,
# per PyMuPDFLoader's convention — confirm if page granularity matters).
print("Loading documents")
loader = PyMuPDFLoader("rdna3.pdf")
documents = loader.load()

# Stage 2: split documents into overlapping chunks. Chunk sizes are counted
# in GPT-2 tokens (not characters) because the splitter is built from a
# HuggingFace tokenizer.
print("Creating chunks")
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
chunks = splitter.split_documents(documents)

# Stage 3: embed every chunk and index the vectors in an in-memory FAISS
# store. `db` is a module global consumed by ask() below.
print("Creating database")
# Use own copy of model as a workaround for sentence-transformers' being down.
embeddings = HuggingFaceEmbeddings(model_name="3gg/all-mpnet-base-v2")
db = FAISS.from_documents(chunks, embeddings)

# Stage 4: load the generative model locally and wrap it in a "stuff" QA
# chain (all retrieved chunks are stuffed into a single prompt).
# temperature=0 requests deterministic output. `chain` is also a module
# global consumed by ask().
print("Loading model")
llm = HuggingFacePipeline.from_model_id(
    model_id="google/flan-t5-large",
    task="text2text-generation",
    model_kwargs={"temperature": 0, "max_length": MODEL_MAX_LENGTH})
chain = load_qa_chain(llm, chain_type="stuff")

def ask(question):
    """Answer *question* from the indexed PDF.

    Retrieves the SIMILARITY_SEARCH_COUNT most similar chunks from the
    FAISS store and feeds them, together with the question, through the
    "stuff" QA chain. Returns the model's answer string.
    """
    matches = db.similarity_search(question, k=SIMILARITY_SEARCH_COUNT)
    return chain.run(input_documents=matches, question=question)

# Warm up.
# First call pays the one-off model/pipeline initialization cost so the
# first real user query is fast; the answer itself is discarded.
ask("What is VGPR")

# Web UI: a single question textbox wired straight to ask(); flagging is
# disabled since answers are not collected for review.
iface = gr.Interface(
    fn=ask,
    inputs=gr.Textbox(label="Question", placeholder="What is..."),
    outputs=gr.Textbox(label="Answer"),
    allow_flagging="never")

# share=False keeps the server local (no public gradio.live tunnel).
iface.launch(share=False)