llm-qa-bench / qa_chain_test.py
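
# Benchmark / chat driver for the QA chain. Usage inferred from the argument
# handling below (these are not documented CLI flags):
#   python qa_chain_test.py          # run all questions from QUESTIONS_FILE_PATH
#   python qa_chain_test.py 10       # run only the first 10 questions
#   python qa_chain_test.py chat     # interactive chat mode (forces BATCH_SIZE=1)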
import json
import os
import sys
import pandas as pd
from timeit import default_timer as timer
import nltk
chatting = len(sys.argv) > 1 and sys.argv[1] == "chat"
if chatting:
    os.environ["BATCH_SIZE"] = "1"
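
# Note: app_modules are imported only after BATCH_SIZE is set above,
# presumably so app_init() sees the chat-mode override when it configures
# the loader.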
from app_modules.init import app_init
from app_modules.llm_qa_chain import QAChain
from app_modules.utils import (
    print_llm_response,
    calc_metrics,
    detect_repetition_scores,
)
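
# app_init() builds the model loader and the QA chain; whether the chain is a
# RAG QAChain or a plain LLM chain depends on the app's configuration.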
llm_loader, qa_chain = app_init()
if chatting:
    print("Starting chat mode")
    while True:
        question = input("Please enter your question: ")
        if question.lower() == "exit":
            break
        result = qa_chain.call_chain({"question": question, "chat_history": []}, None)
        print_llm_response(result)
    sys.exit(0)
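
# Benchmark mode: an optional numeric argv[1] limits how many questions to run.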
num_of_questions = 0
if len(sys.argv) > 1:
    num_of_questions = int(sys.argv[1])
# Create an empty DataFrame with column names
df = pd.DataFrame(
    columns=[
        "id",
        "question",
        "answer",
    ]
)
batch_size = int(os.getenv("BATCH_SIZE", "1"))
print(f"Batch size: {batch_size}")
questions_file_path = os.environ.get("QUESTIONS_FILE_PATH")
debug_retrieval = os.getenv("DEBUG_RETRIEVAL", "false").lower() == "true"
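
# Example invocation with the environment this script reads (hypothetical
# values, for illustration only):
#   QUESTIONS_FILE_PATH=./data/questions.json BATCH_SIZE=4 DEBUG_RETRIEVAL=true \
#       python qa_chain_test.py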
# Open the file for reading
print(f"Reading questions from file: {questions_file_path}")
with open(questions_file_path) as f:
    test_data = json.load(f)
if isinstance(test_data, dict):
    # Dict format: keys are the question ids, values are the question rows.
    questions = [test_data[key] for key in test_data.keys()]
    ids = [key for key in test_data.keys()]
else:
    # List format: each row carries its own "id" field.
    questions = test_data
    ids = [row["id"] for row in questions]

if num_of_questions > 0:
    questions = questions[:num_of_questions]
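
# Illustrative question-file shapes (hypothetical values; field names inferred
# from the code above and the ground-truth lookup below):
#   dict: {"q1": {"question": "...", "answers": ["..."]}, ...}
#   list: [{"id": "q1", "question": "...", "answers": ["..."]}, ...]
# Rows may also carry a "wellFormedAnswers" field, preferred as ground truth.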
print(f"Number of questions: {len(questions)}")
if __name__ == "__main__":
    chat_start = timer()
    index = 0
    while index < len(questions):
        batch_ids = ids[index : index + batch_size]
        batch_questions = [
            q["question"] for q in questions[index : index + batch_size]
        ]
        # RAG chains (QAChain) expect a chat_history entry; plain LLM chains don't.
        if isinstance(qa_chain, QAChain):
            inputs = [{"question": q, "chat_history": []} for q in batch_questions]
        else:
            inputs = [{"question": q} for q in batch_questions]
        start = timer()
        result = qa_chain.call_chain(inputs, None)
        end = timer()
        print(f"Completed in {end - start:.3f}s")
# print("result:", result)
batch_answers = [r["answer"] for r in result]
for id, question, answer in zip(batch_ids, batch_questions, batch_answers):
df.loc[len(df)] = {
"id": id,
"question": question,
"answer": answer,
}
index += batch_size
for r in result:
print_llm_response(r, debug_retrieval)
chat_end = timer()
total_time = chat_end - chat_start
print(f"Total time used: {total_time:.3f} s")
    df2 = pd.DataFrame(
        columns=[
            "id",
            "question",
            "answer",
            "word_count",
            "ground_truth",
        ]
    )
    for i in range(len(df)):
        question = questions[i]
        answer = df["answer"][i]
        query = df["question"][i]
        id = df["id"][i]
        # Prefer the curated "wellFormedAnswers" field as ground truth when
        # present, falling back to the raw "answers" field.
        ground_truth = question[
            "wellFormedAnswers" if "wellFormedAnswers" in question else "answers"
        ]
        word_count = len(nltk.word_tokenize(answer))
        df2.loc[len(df2)] = {
            "id": id,
            "question": query,
            "answer": answer,
            "word_count": word_count,
            "ground_truth": ground_truth,
        }
df2[["newline_score", "repetition_score", "total_repetitions"]] = df2[
"answer"
].apply(detect_repetition_scores)
pd.options.display.float_format = "{:.3f}".format
print(df2.describe())
word_count = df2["word_count"].sum()
    csv_file = (
        os.getenv("TEST_RESULTS_CSV_FILE") or f"qa_batch_{batch_size}_test_results.csv"
    )
    # Write run metadata as "#" comment lines, then append the CSV body.
    with open(csv_file, "w") as f:
        f.write(
            f"# RAG: {isinstance(qa_chain, QAChain)} questions: {questions_file_path}\n"
        )
        f.write(
            f"# model: {llm_loader.model_name} repetition_penalty: {llm_loader.repetition_penalty}\n"
        )
    df2.to_csv(csv_file, mode="a", index=False, header=True)
    print(f"test results saved to file: {csv_file}")
    scores = calc_metrics(df2)

    df = pd.DataFrame(
        {
            "model": [llm_loader.model_name],
            "repetition_penalty": [llm_loader.repetition_penalty],
            "word_count": [word_count],
            "inference_time": [total_time],
            "inference_speed": [word_count / total_time],
            "bleu1": [scores["bleu_scores"]["bleu"]],
            "rougeL": [scores["rouge_scores"]["rougeL"]],
        }
    )

    print(f"Number of words generated: {word_count}")
    print(f"Average generation speed: {word_count / total_time:.3f} words/s")

    csv_file = os.getenv("ALL_RESULTS_CSV_FILE") or "qa_chain_all_results.csv"
    # Append a one-row summary; write the header only if the file is new/empty.
    file_existed = os.path.exists(csv_file) and os.path.getsize(csv_file) > 0
    df.to_csv(csv_file, mode="a", index=False, header=not file_existed)
    print(f"all results appended to file: {csv_file}")