|
import os |
|
|
|
from langchain import FAISS, OpenAI, HuggingFaceHub, Cohere, PromptTemplate |
|
from langchain.chains import RetrievalQA, ConversationalRetrievalChain |
|
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceEmbeddings, CohereEmbeddings |
|
from langchain.memory import ConversationBufferMemory |
|
from langchain.schema import Document |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter, NLTKTextSplitter, \ |
|
SpacyTextSplitter |
|
from langchain.vectorstores import Chroma, ElasticVectorSearch |
|
from pypdf import PdfReader |
|
|
|
from schema import EmbeddingTypes, IndexerType, TransformType, BotType |
|
|
|
|
|
class QnASystem: |
|
|
|
def read_and_load_pdf(self, f_data): |
|
pdf_data = PdfReader(f_data) |
|
documents = [] |
|
for idx, page in enumerate(pdf_data.pages): |
|
documents.append(Document(page_content=page.extract_text(), |
|
metadata={"page_no": idx, "source": f_data.name})) |
|
|
|
self.documents = documents |
|
|
|
def document_transformer(self, transform_type: TransformType): |
|
match transform_type: |
|
case TransformType.CharacterTransform: |
|
t_type = CharacterTextSplitter(chunk_size=1000, chunk_overlap=20) |
|
case TransformType.RecursiveTransform: |
|
t_type = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20) |
|
case TransformType.NLTKTransform: |
|
t_type = NLTKTextSplitter() |
|
case TransformType.SpacyTransform: |
|
t_type = SpacyTextSplitter() |
|
|
|
case _: |
|
raise IndexError("Invalid Transformer Type") |
|
|
|
self.transformed_documents = t_type.split_documents(documents=self.documents) |
|
|
|
def generate_embeddings(self, embedding_type: EmbeddingTypes = EmbeddingTypes.OPENAI, |
|
indexer_type: IndexerType = IndexerType.FAISS, **kwargs): |
|
temperature = kwargs.get("temperature", 0) |
|
max_tokens = kwargs.get("max_tokens", 512) |
|
match embedding_type: |
|
case EmbeddingTypes.OPENAI: |
|
os.environ["OPENAI_API_KEY"] = kwargs.get("api_key") or os.getenv("OPENAI_API_KEY") |
|
embeddings = OpenAIEmbeddings() |
|
llm = OpenAI(temperature=temperature, max_tokens=max_tokens) |
|
case EmbeddingTypes.HUGGING_FACE: |
|
embeddings = HuggingFaceEmbeddings(model_name=kwargs.get("model_name")) |
|
llm = HuggingFaceHub(repo_id=kwargs.get("model_name"), |
|
model_kwargs={"temperature": temperature, "max_tokens": max_tokens}) |
|
case EmbeddingTypes.COHERE: |
|
embeddings = CohereEmbeddings(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key")) |
|
llm = Cohere(model=kwargs.get("model_name"), cohere_api_key=kwargs.get("api_key"), |
|
model_kwargs={"temperature": temperature, |
|
"max_tokens": max_tokens}) |
|
case _: |
|
raise IndexError("Invalid Embedding Type") |
|
|
|
match indexer_type: |
|
case IndexerType.FAISS: |
|
indexer = FAISS |
|
case IndexerType.CHROMA: |
|
indexer = Chroma() |
|
|
|
case IndexerType.ELASTICSEARCH: |
|
indexer = ElasticVectorSearch(elasticsearch_url=kwargs.get("elasticsearch_url")) |
|
case _: |
|
raise IndexError("Invalid Indexer Function") |
|
|
|
self.llm = llm |
|
self.indexer = indexer |
|
self.vector_store = indexer.from_documents(documents=self.transformed_documents, embedding=embeddings) |
|
|
|
def get_retriever(self, search_type="similarity", top_k=5, **kwargs): |
|
retriever = self.vector_store.as_retriever(search_type=search_type, search_kwargs={"k": top_k}) |
|
self.retriever = retriever |
|
|
|
def get_prompt(self, bot_type: BotType, **kwargs): |
|
match bot_type: |
|
case BotType.qna: |
|
prompt = """ |
|
You are a smart and helpful AI assistant, who answer the question given context |
|
{context} |
|
Question: {question} |
|
""" |
|
case BotType.conversational: |
|
prompt = """ |
|
Given the following conversation and a follow up question, |
|
rephrase the follow up question to be a standalone question, in its original language. |
|
\nChat History:\n{chat_history}\nFollow Up Input: {question}\nStandalone question: |
|
""" |
|
return PromptTemplate(input_variables=["context", "question", "chat_history"], template=prompt) |
|
|
|
def build_qa(self, qa_type: BotType, chain_type="stuff", |
|
return_documents: bool = True, **kwargs): |
|
match qa_type: |
|
case BotType.qna: |
|
self.chain = RetrievalQA.from_chain_type(llm=self.llm, retriever=self.retriever, chain_type=chain_type, |
|
return_source_documents=return_documents, verbose=True) |
|
|
|
case BotType.conversational: |
|
self.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True, |
|
output_key="answer") |
|
self.chain = ConversationalRetrievalChain.from_llm(llm=self.llm, retriever=self.retriever, |
|
chain_type=chain_type, |
|
return_source_documents=return_documents, |
|
memory=self.memory, verbose=True) |
|
|
|
case _: |
|
raise IndexError("Invalid QA Type") |
|
|
|
def ask_question(self, query): |
|
if type(self.chain) == RetrievalQA: |
|
data = {"query": query} |
|
else: |
|
data = {"question": query} |
|
return self.chain(data) |
|
|
|
def build_chain(self, transform_type, embedding_type, indexer_type, **kwargs): |
|
if hasattr(self, "llm"): |
|
return self.chain |
|
self.document_transformer(transform_type) |
|
self.generate_embeddings(embedding_type=embedding_type, |
|
indexer_type=indexer_type, **kwargs) |
|
self.get_retriever(**kwargs) |
|
qa = self.build_qa(qa_type=kwargs.get("bot_type"), **kwargs) |
|
return qa |
|
|
|
|
|
if __name__ == "__main__":
    # Demo: index a local PDF and hold a short two-turn conversation over it.
    system = QnASystem()
    with open("../docs/Doc A.pdf", "rb") as pdf_file:
        system.read_and_load_pdf(pdf_file)
    qa_chain = system.build_chain(
        transform_type=TransformType.RecursiveTransform,
        embedding_type=EmbeddingTypes.OPENAI, indexer_type=IndexerType.FAISS,
        chain_type="map_reduce", bot_type=BotType.conversational, return_documents=True
    )
    # First turn seeds the conversation memory; only the second answer is shown.
    response = system.ask_question(query="Hi! Summarize the document.")
    response = system.ask_question(query="What happened from June 1984 to September 1996")
    print(response)
|
|