import argparse
import os
import sys
import time
from typing import Any, Dict, Optional

import pandas as pd
from langchain.prompts import load_prompt
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.output_parsers import StrOutputParser
from transformers import AutoTokenizer

# Make the kit and repo roots importable so the retrieval module resolves.
current_dir = os.path.dirname(os.path.abspath(__file__))
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))
sys.path.append(kit_dir)
sys.path.append(repo_dir)

from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain

class TimedRetrievalQAChain(RetrievalQAChain):
    """RetrievalQAChain variant that records timing checkpoints in its output."""

    # Override _call to time the retrieval/rerank step and the LLM call separately.
    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        qa_chain = self.qa_prompt | self.llm | StrOutputParser()
        response = {}
        start_time = time.time()
        documents = self.retriever.invoke(inputs["question"])
        if self.rerank:
            documents = self.rerank_docs(inputs["question"], documents, self.final_k_retrieved_documents)
        docs = self._format_docs(documents)
        end_preprocessing_time = time.time()
        response["answer"] = qa_chain.invoke({"question": inputs["question"], "context": docs})
        end_llm_time = time.time()
        response["source_documents"] = documents
        response["start_time"] = start_time
        response["end_preprocessing_time"] = end_preprocessing_time
        response["end_llm_time"] = end_llm_time
        return response

def analyze_times(answer, start_time, end_preprocessing_time, end_llm_time, tokenizer):
    """Derive per-stage latencies and token throughput from the recorded timestamps."""
    preprocessing_time = end_preprocessing_time - start_time
    llm_time = end_llm_time - end_preprocessing_time
    token_count = len(tokenizer.encode(answer))
    tokens_per_second = token_count / llm_time
    perf = {
        "preprocessing_time": preprocessing_time,
        "llm_time": llm_time,
        "token_count": token_count,
        "tokens_per_second": tokens_per_second,
    }
    return perf
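
# Illustrative numbers (made up, not from a real run): if preprocessing ends 1.5 s
# after start and the LLM finishes 3.0 s later, a 90-token answer gives:
#
#   perf = analyze_times(answer, 0.0, 1.5, 4.5, tokenizer)
#   # perf["preprocessing_time"] -> 1.5
#   # perf["llm_time"]           -> 3.0
#   # perf["tokens_per_second"]  -> 30.0  (assuming tokenizer.encode yields 90 tokens)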

def generate(qa_chain, question, tokenizer):
    """Run one question through the chain; return the answer, its source files, and timing stats."""
    response = qa_chain.invoke({"question": question})
    answer = response.get("answer")
    # Deduplicate source filenames across the retrieved documents.
    sources = set(sd.metadata["filename"] for sd in response["source_documents"])
    times = analyze_times(
        answer,
        response.get("start_time"),
        response.get("end_preprocessing_time"),
        response.get("end_llm_time"),
        tokenizer,
    )
    return answer, sources, times
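
# Sketch of a single-question call, assuming a chain and tokenizer built as in
# process_bulk_QA below (the question string is just a placeholder):
#
#   answer, sources, times = generate(qa_chain, "What does the report conclude?", tokenizer)
#   print(answer)
#   print(f"{times['tokens_per_second']:.1f} tokens/s from {sorted(sources)}")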

def process_bulk_QA(vectordb_path, questions_file_path):
    documentRetrieval = DocumentRetrieval()
    # The tokenizer is only used to count answer tokens for throughput metrics.
    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
    if os.path.exists(vectordb_path):
        # Load the vectorstore
        embeddings = documentRetrieval.load_embedding_model()
        vectorstore = documentRetrieval.load_vdb(vectordb_path, embeddings)
        print("Database loaded")
        documentRetrieval.init_retriever(vectorstore)
        print("Retriever initialized")
        # Build the timed QA chain
        qa_chain = TimedRetrievalQAChain(
            retriever=documentRetrieval.retriever,
            llm=documentRetrieval.llm,
            qa_prompt=load_prompt(os.path.join(kit_dir, documentRetrieval.prompts["qa_prompt"])),
            rerank=documentRetrieval.retrieval_info["rerank"],
            final_k_retrieved_documents=documentRetrieval.retrieval_info["final_k_retrieved_documents"],
        )
    else:
        raise FileNotFoundError(f"vector db path {vectordb_path} does not exist")
    if os.path.exists(questions_file_path):
        df = pd.read_excel(questions_file_path)
        print(df)
        output_file_path = questions_file_path.replace(".xlsx", "_output.xlsx")
        # Initialize the output columns on first run so partial results can be resumed.
        if "Answer" not in df.columns:
            df["Answer"] = ""
            df["Sources"] = ""
            df["preprocessing_time"] = ""
            df["llm_time"] = ""
            df["token_count"] = ""
            df["tokens_per_second"] = ""
        for index, row in df.iterrows():
            # Only process rows whose 'Answer' is empty (NaN cells count as empty).
            if pd.isna(row["Answer"]) or str(row["Answer"]).strip() == "":
                try:
                    # Generate the answer
                    print(f"Generating answer for row {index}")
                    answer, sources, times = generate(qa_chain, row["Questions"], tokenizer)
                    df.at[index, "Answer"] = answer
                    df.at[index, "Sources"] = ", ".join(sorted(sources))
                    df.at[index, "preprocessing_time"] = times.get("preprocessing_time")
                    df.at[index, "llm_time"] = times.get("llm_time")
                    df.at[index, "token_count"] = times.get("token_count")
                    df.at[index, "tokens_per_second"] = times.get("tokens_per_second")
                except Exception as e:
                    print(f"Error processing row {index}: {e}")
                # Save the file after each iteration to avoid data loss
                df.to_excel(output_file_path, index=False)
            else:
                print(f"Skipping row {index} because 'Answer' is already in the document")
        return output_file_path
    else:
        raise FileNotFoundError(f"questions file path {questions_file_path} does not exist")

if __name__ == "__main__":
    # Parse the arguments
    parser = argparse.ArgumentParser(
        description="Use a vector db and an Excel file with questions, and generate an answer for every question."
    )
    parser.add_argument("vectordb_path", type=str, help="vector db path with stored documents for RAG")
    parser.add_argument("questions_path", type=str, help="xlsx file containing questions in a column named 'Questions'")
    args = parser.parse_args()
    # Process in bulk
    out_file = process_bulk_QA(args.vectordb_path, args.questions_path)
    print(f"Finished, responses in: {out_file}")