# llm-qa-bench/app_modules/llm_qa_chain.py
import json
import os
from typing import List

import pandas as pd
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.base import Chain
from langchain.globals import get_debug
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever

from app_modules.llm_inference import LLMInference
from app_modules.utils import CustomizedConversationSummaryBufferMemory
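
# Environment variables consulted in this module (the values shown are
# illustrative examples, not defaults defined here):
#   RETRIEVER_TYPE=questions_file      -> answer from DatasetRetriever below
#   APPLY_CHAT_TEMPLATE_FOR_RAG=true   -> wrap RAG prompts with the model's chat template
#   QUESTIONS_FILE_PATH=questions.json -> JSON file of question/context records
#   CHAT_HISTORY_ENABLED=true          -> add conversation summary memory to the chain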
retrieve_from_questions_file = os.getenv("RETRIEVER_TYPE") == "questions_file"
apply_chat_template_for_rag = os.getenv("APPLY_CHAT_TEMPLATE_FOR_RAG") == "true"

print(f"retrieve_from_questions_file: {retrieve_from_questions_file}", flush=True)
print(f"apply_chat_template_for_rag: {apply_chat_template_for_rag}", flush=True)

if retrieve_from_questions_file:
    questions_file_path = os.getenv("QUESTIONS_FILE_PATH")
    questions_df = pd.read_json(questions_file_path)
    print(f"Questions file loaded: {questions_file_path}", flush=True)
class DatasetRetriever(BaseRetriever):
    """Retriever that returns the pre-extracted context for a known question."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Get documents relevant to a query.

        Args:
            query: String to find relevant documents for
            run_manager: The callbacks handler to use

        Returns:
            List of relevant documents
        """
        # Match the query (case-insensitively) against the "question" column
        # and wrap each matching row's context in a Document.
        filtered = questions_df[
            questions_df["question"].str.lower() == query.lower()
        ]
        docs = [
            Document(
                page_content=row["context"],
                metadata={"source": row["id"]},
            )
            for _, row in filtered.iterrows()
        ]
        if not docs:
            print(f"No documents found for query: {query}", flush=True)
        return docs


class QAChain(LLMInference):
    def __init__(self, vectorstore, llm_loader):
        super().__init__(llm_loader)
        self.vectorstore = vectorstore

    def create_chain(self) -> Chain:
        # Either look answers up in the pre-extracted contexts from the
        # questions file, or run a similarity search against the vector store.
        if retrieve_from_questions_file:
            retriever = DatasetRetriever()
        else:
            retriever = self.vectorstore.as_retriever(
                search_kwargs=self.llm_loader.search_kwargs
            )

        if os.environ.get("CHAT_HISTORY_ENABLED") == "true":
            # Summarize earlier turns into a bounded chat-history buffer.
            memory = CustomizedConversationSummaryBufferMemory(
                llm=self.llm_loader.llm,
                output_key="answer",
                memory_key="chat_history",
                max_token_limit=1024,
                return_messages=True,
            )
            qa = ConversationalRetrievalChain.from_llm(
                self.llm_loader.llm,
                memory=memory,
                chain_type="stuff",
                retriever=retriever,
                get_chat_history=lambda h: h,
                return_source_documents=True,
            )
        else:
            qa = ConversationalRetrievalChain.from_llm(
                self.llm_loader.llm,
                retriever=retriever,
                max_tokens_limit=8192,  # self.llm_loader.max_tokens_limit,
                return_source_documents=True,
            )

        return qa
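
    # Note: the ConversationalRetrievalChain built above is called with
    # {"question": ..., "chat_history": [...]} and, because
    # return_source_documents=True, its output includes "answer" and
    # "source_documents". The two hooks below rewrite raw inputs into full
    # RAG prompts when the underlying model type is "huggingface".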
def _process_inputs(self, inputs):
if isinstance(inputs, list) and self.llm_loader.llm_model_type == "huggingface":
inputs = [self.get_prompt(i) for i in inputs]
if get_debug():
print("_process_inputs:", json.dumps(inputs, indent=4))
return inputs

    def get_prompt(self, inputs):
        qa_system_prompt = (
            "Use the following pieces of context to answer the question at the end. "
            "If you don't know the answer, just say that you don't know, "
            "don't try to make up an answer."
        )
        query = inputs["question"]

        # Look up the pre-extracted context for this question; this assumes
        # the questions file was loaded (RETRIEVER_TYPE=questions_file).
        filtered = questions_df[
            questions_df["question"].str.lower() == query.lower()
        ]
        context = filtered.iloc[0]["context"] if len(filtered) > 0 else ""

        if apply_chat_template_for_rag:
            return self.apply_chat_template(
                f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}"
            )
        else:
            return f"{qa_system_prompt}\n\n{context}\n\nQuestion: {query}\n\nHelpful Answer:"
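

# Illustrative smoke test only, not one of the app's entry points: it shows
# the record shape DatasetRetriever works with by pointing questions_df at a
# small in-memory DataFrame. The id/question/context values are made up.
if __name__ == "__main__":
    questions_df = pd.DataFrame(
        [
            {
                "id": "example-1",
                "question": "What does this module do?",
                "context": "It wires a retriever and an LLM into a RAG chain.",
            }
        ]
    )
    retriever = DatasetRetriever()
    for doc in retriever.invoke("What does this module do?"):
        print(doc.page_content, "|", doc.metadata)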