File size: 1,936 Bytes
91d72cc 0682ee2 5c402d8 0682ee2 804c5c7 0682ee2 047da88 91d72cc 5c402d8 047da88 5c402d8 047da88 074f5a4 047da88 5c402d8 6fcd382 5c402d8 804c5c7 047da88 0682ee2 804c5c7 047da88 0682ee2 804c5c7 600cadf 0682ee2 804c5c7 5c402d8 804c5c7 5c402d8 0682ee2 5c402d8 0682ee2 4ade5dc 0682ee2 f8d53e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
from langchain import HuggingFaceHub
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import PyMuPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from transformers import GPT2TokenizerFast
# --- Tunables -----------------------------------------------------------
# Number of search results to query from the vector database.
SIMILARITY_SEARCH_COUNT = 8
# Size of each document chunk in number of tokens.
CHUNK_SIZE = 100
# Chunk overlap in number of tokens.
CHUNK_OVERLAP = 10
# Maximum number of output tokens.
MODEL_MAX_LENGTH = 500
# --- One-time setup (runs at import/startup; downloads models on first run) ---
# Load the source PDF into LangChain Document objects, one per page.
print("Loading documents")
loader = PyMuPDFLoader("rdna3.pdf")
documents = loader.load()
# Split pages into overlapping chunks, measured in GPT-2 tokens rather than
# characters so chunk sizes line up with what the LLM actually consumes.
print("Creating chunks")
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
tokenizer, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
chunks = splitter.split_documents(documents)
# Embed every chunk and build an in-memory FAISS similarity index over them.
print("Creating database")
# Use own copy of model as a workaround for sentence-transformers' being down.
embeddings = HuggingFaceEmbeddings(model_name="3gg/all-mpnet-base-v2")
db = FAISS.from_documents(chunks, embeddings)
# Load the generator model locally and wrap it in a "stuff" QA chain
# (all retrieved chunks are stuffed into a single prompt).
print("Loading model")
llm = HuggingFacePipeline.from_model_id(
    model_id="google/flan-t5-large",
    task="text2text-generation",
    model_kwargs={"temperature": 0, "max_length": MODEL_MAX_LENGTH})
chain = load_qa_chain(llm, chain_type="stuff")
def ask(question):
    """Answer *question* from the indexed PDF.

    Retrieves the SIMILARITY_SEARCH_COUNT most similar chunks from the
    FAISS index and feeds them, together with the question, to the QA
    chain. Returns the chain's answer string.
    """
    matches = db.similarity_search(question, k=SIMILARITY_SEARCH_COUNT)
    return chain.run(input_documents=matches, question=question)
# Warm up: run one query before serving so lazy initialization (model and
# pipeline first-use costs) happens at startup, not on the first user request.
ask("What is VGPR")
# Minimal gradio UI: one question textbox in, one answer textbox out,
# with flagging disabled.
iface = gr.Interface(
    fn=ask,
    inputs=gr.Textbox(label="Question", placeholder="What is..."),
    outputs=gr.Textbox(label="Answer"),
    allow_flagging="never")
# share=False: serve locally only; no public gradio tunnel link.
iface.launch(share=False)
|