# isa / app.py
# Last change ("3gg", commit 047da88): Improve search results with tokenizer
# length function and by removing TOC from the pdf.
import gradio as gr
from langchain import HuggingFaceHub
from langchain.chains.question_answering import load_qa_chain
from langchain.document_loaders import PyMuPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from transformers import GPT2TokenizerFast
# Number of search results to query from the vector database.
SIMILARITY_SEARCH_COUNT = 8
# Size of each document chunk in number of tokens (as counted by the GPT-2
# tokenizer below, not characters).
CHUNK_SIZE = 100
# Chunk overlap in number of tokens.
CHUNK_OVERLAP = 10
# Maximum number of output tokens.
MODEL_MAX_LENGTH = 500
# --- One-time pipeline setup (runs at import time) ---
# Load the source PDF into LangChain documents.
print("Loading documents")
loader = PyMuPDFLoader("rdna3.pdf")
documents = loader.load()
# Split documents into overlapping chunks, measuring length in GPT-2 tokens
# rather than characters so chunk sizes match what the model sees.
print("Creating chunks")
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
chunks = splitter.split_documents(documents)
# Embed every chunk and build an in-memory FAISS similarity-search index.
print("Creating database")
# Use own copy of model as a workaround for sentence-transformers' being down.
embeddings = HuggingFaceEmbeddings(model_name="3gg/all-mpnet-base-v2")
db = FAISS.from_documents(chunks, embeddings)
# Load the generative model and wire it into a "stuff" QA chain (all
# retrieved chunks are stuffed into a single prompt).
print("Loading model")
llm = HuggingFacePipeline.from_model_id(
    model_id="google/flan-t5-large",
    task="text2text-generation",
    model_kwargs={"temperature": 0, "max_length": MODEL_MAX_LENGTH})
chain = load_qa_chain(llm, chain_type="stuff")
def ask(question):
    """Answer a free-form question about the indexed PDF.

    Looks up the SIMILARITY_SEARCH_COUNT chunks most similar to the
    question in the FAISS index, then runs the QA chain over them.
    """
    relevant_chunks = db.similarity_search(question, k=SIMILARITY_SEARCH_COUNT)
    return chain.run(input_documents=relevant_chunks, question=question)
# Warm up: run one query now so the first user request is not slowed down by
# lazy initialization inside the pipeline.
ask("What is VGPR")
# Simple single-textbox Gradio UI around ask(); flagging is disabled.
iface = gr.Interface(
    fn=ask,
    inputs=gr.Textbox(label="Question", placeholder="What is..."),
    outputs=gr.Textbox(label="Answer"),
    allow_flagging="never")
iface.launch(share=False)