# src/chains.py
from langchain.document_loaders import PyPDFDirectoryLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQA, ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory
from langchain.chat_models import AzureChatOpenAI

import os
from typing import Optional

import openai
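# NOTE: these import paths assume a legacy (pre-0.1) `langchain` release; newer releases
# moved these classes into `langchain_community` and `langchain_openai`.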
os.environ['CWD'] = os.getcwd()

import src.constants as constants
# import constants  # use this flat import instead when running this file directly for testing

# Route the OpenAI SDK to the Azure OpenAI (France) deployment via environment variables.
os.environ['OPENAI_API_KEY'] = constants.AZURE_OPENAI_KEY_FR
os.environ['OPENAI_API_BASE'] = constants.AZURE_OPENAI_ENDPOINT_FR
os.environ['OPENAI_API_VERSION'] = "2023-05-15"
os.environ['OPENAI_API_TYPE'] = "azure"
# Equivalent configuration set directly on the openai module:
# openai.api_type = "azure"
# openai.api_base = constants.AZURE_OPENAI_ENDPOINT_FR
# openai.api_version = "2023-05-15"
# Key for the non-Azure OpenAI endpoint, set directly on the openai client.
openai.api_key = constants.OPEN_AI_KEY
def get_document_key(doc):
    """Build a unique key for a document chunk from its source path and page number."""
    return doc.metadata['source'] + '_page_' + str(doc.metadata['page'])
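# Illustrative example (hypothetical file name): a chunk loaded from "archive/paper.pdf",
# page 3, yields the key "archive/paper.pdf_page_3".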
class PDFEmbeddings:
    """Embeds PDFs from a directory into a Chroma vector store and exposes retrieval QA helpers."""

    def __init__(self, path: Optional[str] = None):
        self.path = path or os.path.join(os.environ['CWD'], 'archive')
        self.text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1000, chunk_overlap=200)
        # Embeddings are served by the Azure OpenAI (US) deployment; chunk_size=1 sends one text per request.
        self.embeddings = OpenAIEmbeddings(deployment=constants.AZURE_ENGINE_NAME_US, chunk_size=1,
                                           openai_api_key=constants.AZURE_OPENAI_KEY_US,
                                           openai_api_base=constants.AZURE_OPENAI_ENDPOINT_US,
                                           openai_api_version="2023-05-15",
                                           openai_api_type="azure")
        self.vectorstore = Chroma(persist_directory=constants.persistent_dir, embedding_function=self.embeddings)
        self.retriever = self.vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 5})
        # ConversationalRetrievalChain expects the history under the "chat_history" key by default.
        self.memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)
    def process_documents(self):
        # Load the documents and process them
        loader = PyPDFDirectoryLoader(self.path)
        documents = loader.load()
        chunks = self.text_splitter.split_documents(documents)
        self.vectorstore.add_documents(chunks)
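        # Note (assumption): with older Chroma/chromadb versions the index is only flushed to
        # persist_directory when self.vectorstore.persist() is called; newer chromadb persists automatically.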
    def search(self, query: str, chain_type: str = "stuff"):
        # One-shot retrieval QA over the vector store; returns the answer plus the chunks it was grounded on.
        chain = RetrievalQA.from_chain_type(llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR, temperature=0),
                                            retriever=self.retriever, chain_type=chain_type,
                                            return_source_documents=True)
        result = chain({"query": query})
        return result
    def conversational_search(self, query: str, chain_type: str = "stuff"):
        # Retrieval QA that keeps the running chat history in self.memory across calls.
        chain = ConversationalRetrievalChain.from_llm(llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR),
                                                      retriever=self.retriever, memory=self.memory,
                                                      chain_type=chain_type)
        result = chain({"question": query})
        return result['answer']
    def load_and_run_chain(self, query: str, chain_type: str = "stuff"):
        # load_qa_chain expects concrete documents, so retrieve the relevant chunks first.
        docs = self.retriever.get_relevant_documents(query)
        chain = load_qa_chain(llm=AzureChatOpenAI(deployment_name=constants.AZURE_ENGINE_NAME_FR), chain_type=chain_type)
        return chain.run(input_documents=docs, question=query)
if __name__ == '__main__':
    pdf_embed = PDFEmbeddings()
    # pdf_embed.process_documents()  # This takes a while, so we only do it once
    result = pdf_embed.search("Give me a list of short relevant queries to look for papers related to the topics of the papers in the source documents.")
    print("\n\n", result['result'], "\n")
    print("Source documents:")
    for doc in result['source_documents']:
        print(doc.metadata['source'])
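    # Illustrative sketch of the other entry points (the queries below are made up):
    # follow_up = pdf_embed.conversational_search("Summarise the first paper in two sentences.")
    # qa_answer = pdf_embed.load_and_run_chain("Which methods do these papers have in common?")
    # print(follow_up)
    # print(qa_answer)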