Research_Assistant / chatbot.py
lara1510's picture
Update chatbot.py
6028f6f verified
raw
history blame
No virus
4.42 kB
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface.llms import HuggingFaceLLM
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_retrieval_chain, create_history_aware_retriever
from langchain_community.document_loaders import PyMuPDFLoader
from langchain_community.llms import Ollama
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
class AdjustedHuggingFaceEmbeddings(HuggingFaceEmbeddings):
    """HuggingFace embeddings that can also be invoked as a plain callable.

    Presumably added so the object satisfies vector-store code that expects an
    embedding *function* with the signature ``__call__(input)`` (verify
    against the Chroma version in use).
    """

    def __call__(self, input):
        """Embed a batch of texts.

        Args:
            input: List of strings to embed. The parameter name is kept as
                ``input`` because callers may pass it by keyword.

        Returns:
            One embedding vector per input text, as produced by
            ``embed_documents``.
        """
        # Delegating to super().__call__ would fail: the parent classes do not
        # appear to define __call__, so route through the embeddings API.
        return self.embed_documents(input)
def create_chain(chains, pdf_doc):
    """Build a history-aware retrieval chain for a PDF and store it in ``chains[0]``.

    Args:
        chains: Mutable sequence used as an output slot; its first element is
            overwritten with the newly built retrieval chain.
        pdf_doc: Uploaded PDF handle passed on to ``create_vector_db``, or
            ``None`` when nothing has been uploaded yet.

    Returns:
        A human-readable status string: an instruction to upload a PDF when
        ``pdf_doc`` is ``None``, otherwise a success message.
    """
    if pdf_doc is None:
        return 'You must convert or upload a pdf first'

    db = create_vector_db(pdf_doc)
    llm = create_model()

    # Prompt that turns the running conversation into a standalone search query.
    prompt_search_query = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        ("user",
         "Given the above conversation, generate a search query to look up to get information relevant to the conversation"),
    ])
    retriever_chain = create_history_aware_retriever(llm, db.as_retriever(), prompt_search_query)

    # Prompt that answers the question from the retrieved context.
    # A real blank line separates the instruction from the context; the
    # previous "\\n\\n" put literal backslash-n characters into the prompt.
    prompt_get_answer = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's questions based on the below context:\n\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
    ])
    combine_docs_chain = create_stuff_documents_chain(llm=llm, prompt=prompt_get_answer)

    chains[0] = create_retrieval_chain(retriever_chain, combine_docs_chain)
    return 'Document has successfully been loaded'
def create_model(model_id="openai-community/gpt2"):
    """Load a causal language model and wrap it as a LangChain LLM.

    Args:
        model_id: Hub id (or local path) of the model and tokenizer to load.
            Defaults to the GPT-2 checkpoint this module has always used.

    Returns:
        A ``HuggingFacePipeline`` wrapping a ``text-generation`` pipeline
        built from the loaded model and tokenizer.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # dtype, quantisation and device placement are decided once, at load time.
    # No explicit auth-token flag is passed: forcing a token makes loading a
    # public checkpoint fail on machines without a stored login.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map='auto',
        torch_dtype=torch.float16,
        load_in_8bit=True,
    )

    # The model is already instantiated, so the pipeline is not given a second
    # (previously conflicting: bfloat16 vs. float16) dtype or device map.
    # NOTE(review): GPT-2's context window is believed to be 1024 tokens, so
    # prompt + 1024 new tokens may overflow it -- confirm and tune.
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=1024,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
    )

    # NOTE(review): temperature 0 sits oddly next to do_sample=True above;
    # verify which decoding behaviour is actually intended.
    return HuggingFacePipeline(pipeline=pipe, model_kwargs={'temperature': 0})
def create_vector_db(doc):
    """Load ``doc``, split it into chunks and index the chunks in Chroma.

    Args:
        doc: Document handle accepted by ``load_document``.

    Returns:
        The Chroma vector store built from the document's chunks.
    """
    chunks = split_document(load_document(doc))
    embedder = AdjustedHuggingFaceEmbeddings()
    return Chroma.from_documents(chunks, embedder)
def load_document(doc):
    """Read a PDF from disk via ``PyMuPDFLoader``.

    Args:
        doc: Object whose ``name`` attribute is the path of the PDF file.

    Returns:
        Whatever ``PyMuPDFLoader.load()`` yields for that file.
    """
    pdf_path = doc.name
    return PyMuPDFLoader(pdf_path).load()
def split_document(doc):
    """Split loaded documents into overlapping chunks.

    Uses a ``RecursiveCharacterTextSplitter`` with a chunk size of 1000 and an
    overlap of 200.

    Args:
        doc: Documents to split, as accepted by ``split_documents``.

    Returns:
        The chunks produced by the splitter.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    return splitter.split_documents(doc)
def save_history(history):
    """Write the conversation to ``history.txt``, one ``- <content>`` line per message.

    The file in the current working directory is overwritten on every call.

    Args:
        history: Iterable of message objects exposing a ``content`` attribute.
    """
    # Explicit UTF-8 so non-ASCII questions/answers are written the same way
    # regardless of the platform's default locale encoding.
    with open('history.txt', 'w', encoding='utf-8') as file:
        file.writelines(f'- {s.content}\n' for s in history)
def answer_query(chain, query: str, chat_history=None) -> str:
    """Run ``query`` through ``chain`` and record the exchange in the history.

    Args:
        chain: Retrieval chain exposing ``invoke``; a falsy value means no
            document has been loaded yet.
        query: The user's question.
        chat_history: List of prior messages. It is extended in place with
            the new question and answer. When omitted, a fresh list is used.

    Returns:
        The chain's answer, or a prompt to load a document when ``chain`` is
        falsy.
    """
    if not chain:
        return "Please load a document first."

    # The default is None; fall back to an empty history rather than crashing
    # on ``None.append``.
    if chat_history is None:
        chat_history = []

    # Invoke with the *prior* turns only: the current question is already
    # passed as 'input', so appending it beforehand would show it to the chain
    # twice (and leave a dangling question behind if the call fails).
    response = chain.invoke({
        'chat_history': chat_history,
        'input': query,
    })
    answer = response['answer']
    print('RESPONSE: ', answer, '\n\n')

    # Record the completed turn, then persist the whole history to disk.
    chat_history.append(HumanMessage(content=query))
    chat_history.append(AIMessage(content=answer))
    save_history(chat_history)
    return answer