"""
LLM chain retrieval
"""
import json
import gradio as gr
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_huggingface import HuggingFaceEndpoint
from langchain_core.prompts import PromptTemplate
# Initialize langchain LLM chain
def initialize_llmchain(
    llm_model,
    huggingfacehub_api_token,
    temperature,
    max_tokens,
    top_k,
    vector_db,
    progress=gr.Progress(),
):
    """Initialize the Langchain LLM chain over the given vector database."""
    progress(0.1, desc="Initializing HF endpoint...")
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        task="text-generation",
        provider="hf-inference",
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
        huggingfacehub_api_token=huggingfacehub_api_token,
    )
    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    # Note: `top_k` doubles as the retriever's document count here, in
    # addition to serving as the LLM's top-k sampling parameter above.
    retriever = vector_db.as_retriever(search_type="similarity", search_kwargs={"k": top_k})
    progress(0.8, desc="Defining retrieval chain...")
    with open("prompt_template.json", "r") as file:
        system_prompt = json.load(file)
    prompt_template = system_prompt["prompt"]
    rag_prompt = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        combine_docs_chain_kwargs={"prompt": rag_prompt},
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain
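
# Illustrative sketch of the expected `prompt_template.json` layout. This is
# an assumption inferred from the "prompt" key and the "context"/"question"
# input variables used above, not the actual file shipped with this app:
#
#   {
#       "prompt": "Use the following context to answer the question.\nContext: {context}\nQuestion: {question}\nHelpful Answer:"
#   }
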
# Format chat history
def format_chat_history(chat_history):
    """Format chat history as alternating User/Assistant lines for the LLM"""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
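
# Example: format_chat_history([("Hi", "Hello! How can I help?")])
#   -> ["User: Hi", "Assistant: Hello! How can I help?"]
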
# Invoke QA chain with history
def invoke_qa_chain(qa_chain, message, history):
    """Invoke the question-answering chain and append the turn to history"""
    formatted_chat_history = format_chat_history(history)
    # The chain's ConversationBufferMemory also tracks chat_history across
    # calls; the value passed here mirrors the Gradio-side history.
    response = qa_chain.invoke({
        "question": message,
        "chat_history": formatted_chat_history,
    })
    response_sources = response["source_documents"]
    response_answer = response["answer"]
    # Strip any echoed "Helpful Answer:" prefix from the generation
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1].strip()
    new_history = history + [(message, response_answer)]
    return qa_chain, new_history, response_sources
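
# Usage sketch. Assumptions: the model id and token below are placeholders,
# and `vector_db` is a vector store (e.g. FAISS or Chroma) built elsewhere
# in this app from the user's documents:
#
#   qa_chain = initialize_llmchain(
#       llm_model="mistralai/Mistral-7B-Instruct-v0.2",
#       huggingfacehub_api_token="hf_...",
#       temperature=0.5,
#       max_tokens=1024,
#       top_k=3,
#       vector_db=vector_db,
#   )
#   qa_chain, history, sources = invoke_qa_chain(qa_chain, "What is this about?", [])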