multirag / chatbot.py
gk8686's picture
Upload 15 files
c9b6cb4 verified
raw
history blame
3.18 kB
import os
import openai
import sys
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import DocArrayInMemorySearch
from langchain.document_loaders import TextLoader
from langchain.text_splitter import MarkdownHeaderTextSplitter
# from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
class Chatbot:
def __init__(self, model_name, knowledge_base_path):
self.embeddings = OpenAIEmbeddings()
# self.path_to_root = "../"
# self.knowledge_base_path = os.path.join(self.path_to_root, knowledge_base_path)
self.knowledge_base_path = knowledge_base_path
print(self.knowledge_base_path)
self._load_knowledge_base(self.knowledge_base_path)
self._load_model(model_name)
self._load_prompt_template()
self._load_qa_chain()
def response(self, query):
return self.qa_chain({"query": query})["result"]
def _load_qa_chain(self):
assert self.LLM, "Model not loaded"
assert self.vector_store, "Knowledge base not loaded"
assert self.qa_chain_prompt, "Prompt template not loaded"
self.qa_chain = RetrievalQA.from_chain_type(
self.LLM,
retriever=self.vector_store.as_retriever(),
return_source_documents=True,
chain_type_kwargs={"prompt": self.qa_chain_prompt},
)
def _load_prompt_template(self):
template = """You are a helpful assistant, use the following context to answer the question at the end. If you don't know the answer, just say "sorry, I can't answer this, the answer to this question does not appear in my knowledge base", don't try to make up an answer.
{context}
Question: {question}
Helpful Answer:"""
self.qa_chain_prompt = PromptTemplate(
input_variables=["context", "question"], template=template
)
def _load_model(self, model_name):
self.LLM = ChatOpenAI(model_name=model_name, temperature=0)
def _load_knowledge_base(self, knowledge_base_path):
assert self.embeddings, "Embeddings not loaded"
assert os.path.exists(knowledge_base_path), "Knowledge base not found"
loader = TextLoader(knowledge_base_path)
corpus = loader.load()
text_data = " ".join([d.page_content for d in corpus])
# Since the knowledge base is formatted like a markdown, using a markdown header splitter to get splits on headers and header information in metadata.
split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3"),
]
markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=split_on)
# markdown_splitter = RecursiveCharacterTextSplitter(
# chunk_size=1000, chunk_overlap=200, length_function=len
# )
md_header_splits = markdown_splitter.split_text(text_data)
self.vector_store = DocArrayInMemorySearch.from_documents(
md_header_splits, self.embeddings
)