# import libraries
import os
from operator import itemgetter

from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

LLM_MODEL_NAME = "gpt-3.5-turbo"
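
# NOTE: OpenAIEmbeddings and ChatOpenAI read the API key from the
# OPENAI_API_KEY environment variable, which must be set before running.
# The imports above assume the langchain, langchain-community,
# langchain-openai, faiss-cpu, and pymupdf packages are installed.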

# load PDF doc and convert to text
def load_pdf_to_text(pdf_path):
    # create a document loader
    loader = PyMuPDFLoader(pdf_path)
    # load the document (one Document per page)
    doc = loader.load()
    return doc

# split documents into overlapping chunks
def split_text(docs):
    # create a text splitter
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=700,
        chunk_overlap=100,
    )
    # split the documents into chunks
    doc_splits = splitter.split_documents(docs)
    return doc_splits

# load document chunks into a FAISS index
def load_text_to_index(doc_splits):
    embeddings = OpenAIEmbeddings(
        model="text-embedding-3-small"
    )
    vector_store = FAISS.from_documents(doc_splits, embeddings)
    retriever = vector_store.as_retriever()
    return retriever
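
# Note: as_retriever() with no arguments defaults to plain similarity
# search with k=4 in current LangChain versions, so query_index() below
# returns up to four chunks per query.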

# query FAISS index
def query_index(retriever, query):
    retrieved_docs = retriever.invoke(query)
    return retrieved_docs

# create answer prompt
def create_answer_prompt():
    template = """Answer the question based only on the following context. If you cannot answer the question with the context, please respond with 'I don't know':
Context:
{context}
Question:
{question}
"""
    print("template length: ", len(template))
    prompt = ChatPromptTemplate.from_template(template)
    return prompt

# generate answer
def generate_answer(retriever, answer_prompt, query):
    print("generate_answer()")
    QnA_LLM = ChatOpenAI(model_name=LLM_MODEL_NAME, temperature=0.0)
    retrieval_qna_chain = (
        # retrieve context for the question and pass the question through
        {"context": itemgetter("question") | retriever, "question": itemgetter("question")}
        # keep the retrieved context available downstream
        | RunnablePassthrough.assign(context=itemgetter("context"))
        # fill the prompt, call the LLM, and return both answer and context
        | {"response": answer_prompt | QnA_LLM, "context": itemgetter("context")}
    )
    result = retrieval_qna_chain.invoke({"question": query})
    return result
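
# The chain above returns a dict shaped like
#   {"response": AIMessage(content="..."), "context": [Document(...), ...]}
# so the final answer text is available as result["response"].content.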

def initialize_index():
    # load pdf
    cwd = os.path.abspath(os.getcwd())
    data_dir = "data"
    pdf_file = "nvidia_earnings_report.pdf"
    # pdf_file = "musk-v-altman-openai-complaint-sf.pdf"
    pdf_path = os.path.join(cwd, data_dir, pdf_file)
    print("path: ", pdf_path)
    doc = load_pdf_to_text(pdf_path)
    print("doc pages: ", len(doc))
    doc_splits = split_text(doc)
    print("doc_splits length: ", len(doc_splits))
    retriever = load_text_to_index(doc_splits)
    return retriever

def main():
    retriever = initialize_index()
    # the active query should match the active PDF in initialize_index()
    query = "Who is the E-VP, Operations"
    # query = "what is the reason for the lawsuit"  # pairs with the Musk v. Altman complaint PDF
    retrieved_docs = query_index(retriever, query)
    print("retrieved_docs: ", len(retrieved_docs))
    answer_prompt = create_answer_prompt()
    print("answer_prompt: \n", answer_prompt)
    result = generate_answer(retriever, answer_prompt, query)
    print("result: \n", result["response"].content)

if __name__ == "__main__":
    main()
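
# Example of reusing the retriever for an ad-hoc lookup (hypothetical
# query; assumes the same data/ layout as initialize_index()):
#
#   retriever = initialize_index()
#   docs = query_index(retriever, "What was data center revenue?")
#   for d in docs:
#       print(d.metadata.get("page"), d.page_content[:100])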