# llm-ai-assistant / src/worker_huggingface.py
# Author: 0xdant
# Commit e8445ba: Prompt improvements, replace deprecated methods and chat
# history support.
import os
import torch
from langchain_core.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import create_retrieval_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import MessagesPlaceholder
from langchain.chains import create_history_aware_retriever
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
# Compute device: the first CUDA GPU when torch can see one, the CPU otherwise.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda:0"

# Module-level state shared by the functions below.
# `llm_hub` and `embeddings` are filled in by init_llm(); `tokenizer` is not
# assigned anywhere in the visible part of this file.
llm_hub = embeddings = tokenizer = None
# Replaced by process_document() with the chain that answers questions.
conversation_retrieval_chain = None
# Not written to anywhere in the visible part of this file.
chat_history = []
def init_llm():
    """Create the hosted LLM client and the embedding model for this module.

    Stores the results in the module-level ``llm_hub`` and ``embeddings``
    globals. No API token is passed explicitly here; set the
    ``HUGGINGFACEHUB_API_TOKEN`` environment variable before calling.
    """
    global llm_hub, embeddings

    # `langchain.embeddings.HuggingFaceEmbeddings` is a deprecated import
    # path; the maintained class lives in `langchain_huggingface`, which this
    # file already depends on for HuggingFaceEndpoint.
    from langchain_huggingface import HuggingFaceEmbeddings

    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
    llm_hub = HuggingFaceEndpoint(
        repo_id=model_id,
        task="text-generation",
        max_new_tokens=200,
        # NOTE(review): sampling is disabled here while a temperature is also
        # set below -- confirm which of the two is intended to take effect.
        do_sample=False,
        repetition_penalty=1.03,
        return_full_text=False,
        temperature=0.1,
    )

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
# Chat histories for the lifetime of the process, keyed by session id.
store = {}


def get_session_history(session_id: str) -> BaseChatMessageHistory:
    """Return the chat history for ``session_id``, creating it on first use.

    A new, empty ``ChatMessageHistory`` is stored in ``store`` the first time
    a session id is seen; later calls return that same object.
    """
    try:
        return store[session_id]
    except KeyError:
        history = store[session_id] = ChatMessageHistory()
        return history
def process_document(document_path):
    """Index a PDF and rebuild the conversational retrieval chain for it.

    Loads the PDF at ``document_path``, splits it into overlapping chunks,
    embeds the chunks into a FAISS index, and stores a history-aware
    retrieval chain in the module-level ``conversation_retrieval_chain``.

    Args:
        document_path: Path of the PDF file, handed to ``PyPDFLoader``.

    Raises:
        ValueError: If the document yields no text chunks to index.
    """
    global conversation_retrieval_chain

    # Load the document and split it into chunks.
    documents = PyPDFLoader(document_path).load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1024, chunk_overlap=128
    )
    texts = text_splitter.split_documents(documents)
    if not texts:
        # Building a FAISS index from zero chunks fails with an opaque
        # IndexError; report the actual problem instead.
        raise ValueError(
            f"No text could be extracted from document: {document_path!r}"
        )

    # Create an embeddings database using FAISS from the split text chunks.
    db = FAISS.from_documents(documents=texts, embedding=embeddings)

    # `lambda_mult` only affects MMR search, so it is not passed for plain
    # similarity search.
    retriever = db.as_retriever(
        search_type="similarity", search_kwargs={"k": 3}
    )

    # Prompt text is kept exactly as before.
    # NOTE(review): it embeds Llama 3 header tokens by hand and repeats the
    # question that the human message also carries -- confirm this is the
    # intended format for the endpoint.
    system_prompt = (
        "\n"
        "<|start_header_id|>user<|end_header_id|>\n"
        "You are an assistant for answering questions using provided context.\n"
        "You are given the extracted parts of a long document, previous "
        "chat_history and a question. Provide a conversational answer.\n"
        "If you don't know the answer, just say \"I do not know.\" "
        "Don't make up an answer.\n"
        "Question: {input}\n"
        "Context: {context}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            # The system prompt promises the model the previous chat history;
            # without this placeholder only the query-rewriting step saw it.
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    question_answer_chain = create_stuff_documents_chain(llm_hub, prompt)

    # Rewrite follow-up questions into standalone ones before retrieval so
    # that references to earlier turns still retrieve the right chunks.
    contextualize_q_system_prompt = (
        "Given a chat history and the latest user question "
        "which might reference context in the chat history, "
        "formulate a standalone question which can be understood "
        "without the chat history. Do NOT answer the question, "
        "just reformulate it if needed and otherwise return it as is."
    )
    contextualize_q_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", contextualize_q_system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "{input}"),
        ]
    )
    history_aware_retriever = create_history_aware_retriever(
        llm_hub, retriever, contextualize_q_prompt
    )

    rag_chain = create_retrieval_chain(
        history_aware_retriever, question_answer_chain
    )

    # Wrap the chain so each session's messages are loaded from, and appended
    # to, the history returned by get_session_history.
    conversation_retrieval_chain = RunnableWithMessageHistory(
        rag_chain,
        get_session_history,
        input_messages_key="input",
        history_messages_key="chat_history",
        output_messages_key="answer",
    )
def process_prompt(prompt, session_id="abc123"):
    """Answer a user question with the current conversational chain.

    Args:
        prompt: The user's question text.
        session_id: Key of the chat history to read and extend for this
            call. Defaults to ``"abc123"``, the id that was previously
            hard-coded, so existing callers keep sharing one history.

    Returns:
        The value stored under the ``"answer"`` key of the chain output.

    Raises:
        RuntimeError: If no chain has been built yet, i.e.
            ``process_document`` has not been called successfully.
    """
    if conversation_retrieval_chain is None:
        # Fail with an actionable message rather than an AttributeError on
        # None.
        raise RuntimeError(
            "No document has been processed yet; call process_document() "
            "before process_prompt()."
        )

    # Query the model; the session id selects which history is used.
    output = conversation_retrieval_chain.invoke(
        {"input": prompt},
        config={"configurable": {"session_id": session_id}},
    )
    answer = output["answer"]
    print(output)

    # Return the model's response
    return answer
# Initialize the language model and embeddings once, when this module is
# imported, so `llm_hub` and `embeddings` are set before any document or
# prompt is processed.
init_llm()