krish-sri-bot / edubot.py
gsvc's picture
Update edubot.py
a48eb0d
from langchain import PromptTemplate
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import CTransformers
from langchain.chains import RetrievalQA
from config import *
class EduBotCreator:
def __init__(self):
self.prompt_temp = PROMPT_TEMPLATE
self.input_variables = INP_VARS
self.chain_type = CHAIN_TYPE
self.search_kwargs = SEARCH_KWARGS
self.embedder = EMBEDDER
self.vector_db_path = VECTOR_DB_PATH
self.model_ckpt = MODEL_CKPT
self.model_type = MODEL_TYPE
self.max_new_tokens = MAX_NEW_TOKENS
self.temperature = TEMPERATURE
def create_custom_prompt(self):
custom_prompt_temp = PromptTemplate(template=self.prompt_temp,
input_variables=self.input_variables)
return custom_prompt_temp
def load_llm(self):
llm = CTransformers(
model = self.model_ckpt,
model_type=self.model_type,
max_new_tokens = self.max_new_tokens,
temperature = self.temperature
)
return llm
def load_vectordb(self):
hfembeddings = HuggingFaceEmbeddings(
model_name=self.embedder,
model_kwargs={'device': 'cpu'}
)
vector_db = FAISS.load_local(self.vector_db_path, hfembeddings)
return vector_db
def create_bot(self, custom_prompt, vectordb, llm):
retrieval_qa_chain = RetrievalQA.from_chain_type(
llm=llm,
chain_type=self.chain_type,
retriever=vectordb.as_retriever(search_kwargs=self.search_kwargs),
return_source_documents=True,
chain_type_kwargs={"prompt": custom_prompt}
)
return retrieval_qa_chain
def create_edubot(self):
self.custom_prompt = self.create_custom_prompt()
self.vector_db = self.load_vectordb()
self.llm = self.load_llm()
self.bot = self.create_bot(self.custom_prompt, self.vector_db, self.llm)
return self.bot