# Research_Assistant / chatbot.py
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_huggingface import HuggingFacePipeline


class AdjustedHuggingFaceEmbeddings(HuggingFaceEmbeddings):
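    """HuggingFaceEmbeddings subclass whose __call__ accepts a single
    `input` argument, matching the embedding-function signature that
    newer Chroma releases validate."""
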
    def __call__(self, input):
        return super().__call__(input)


def create_chain(chains, pdf_doc):
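    """Build a history-aware retrieval chain over the uploaded PDF and store
    it in chains[0]; returns a status message for the UI."""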
    if pdf_doc is None:
        return 'You must convert or upload a PDF first'
    db = create_vector_db(pdf_doc)
    llm = create_model()
    # Prompt that condenses the conversation into a standalone search query.
    prompt_search_query = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user",
         "Given the above conversation, generate a search query to look up to get information relevant to the conversation"),
    ])
    retriever_chain = create_history_aware_retriever(llm, db.as_retriever(), prompt_search_query)
    # Prompt that answers the question from the retrieved context.
    prompt_get_answer = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=prompt_get_answer)
    # Mutate the shared list so the caller (e.g. a Gradio state object) sees the new chain.
    chains[0] = create_retrieval_chain(retriever_chain, combine_docs_chain)
    return 'Document has successfully been loaded'


def create_model():
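    """Load OpenAssistant/oasst-sft-1-pythia-12b in float16 and wrap its
    text-generation pipeline as a LangChain LLM."""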
    hf_api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b")
    model = AutoModelForCausalLM.from_pretrained("OpenAssistant/oasst-sft-1-pythia-12b",
                                                 device_map='auto',
                                                 torch_dtype=torch.float16,
                                                 token=hf_api_token)
    # The model is already loaded in float16 and placed by device_map, so the
    # pipeline reuses it as-is. Sampling: top-k of 10, up to 1024 new tokens.
    pipe = pipeline("text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    max_new_tokens=1024,
                    do_sample=True,
                    top_k=10,
                    num_return_sequences=1,
                    eos_token_id=tokenizer.eos_token_id)
    llm = HuggingFacePipeline(pipeline=pipe)
    return llm


def create_vector_db(doc):
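    """Build an in-memory Chroma vector store from the PDF's chunked text."""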
    document = load_document(doc)
    text = split_document(document)
    embedding = AdjustedHuggingFaceEmbeddings()
    db = Chroma.from_documents(text, embedding)
    return db


def load_document(doc):
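    """Load the PDF via PyMuPDF; `doc` is a file object whose .name holds the path."""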
    loader = PyMuPDFLoader(doc.name)
    document = loader.load()
    return document


def split_document(doc):
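    """Split the loaded pages into 1000-character chunks with 200-character overlap."""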
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    text = text_splitter.split_documents(doc)
    return text


def save_history(history):
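    """Overwrite history.txt with one line per message in the chat history."""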
    with open('history.txt', 'w') as file:
        for s in history:
            file.write(f'- {s.content}\n')


def answer_query(chain, query: str, chat_history=None) -> str:
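    """Answer a query with the retrieval chain, appending both the question
    and the answer to the chat history and saving the history to disk."""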
    if chain:
        if chat_history is None:
            chat_history = []
        # Run the chain with the query and the conversation so far.
        chat_history.append(HumanMessage(content=query))
        response = chain.invoke({
            'chat_history': chat_history,
            'input': query
        })
        answer = response['answer']
        print('RESPONSE: ', answer, '\n\n')
        # Add the answer to the history and persist it to the text file.
        chat_history.append(AIMessage(content=answer))
        save_history(chat_history)
        return answer
    else:
        return "Please load a document first."