# ChatCSV / indexes / query_engine_bk.py
# (renamed from indexes/query_engine.py — commit 3613a7e, by Chamin09)
from typing import Dict, List, Optional, Any
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.retrievers import VectorIndexRetriever
from llama_index.response_synthesizers import ResponseMode
from llama_index.llms import HuggingFaceLLM
from llama_index import ServiceContext, QueryBundle
from llama_index.prompts import PromptTemplate
class CSVQueryEngine:
    """Query engine for natural-language questions over indexed CSV data.

    Supports multi-file queries: relevant CSVs are located through the
    index manager, each file's vector index is queried individually, and
    when more than one file contributes, the per-file answers are merged
    into a single unified answer by the LLM.
    """

    def __init__(self, index_manager, llm, response_mode: str = "compact"):
        """Initialize with an index manager and a language model.

        Args:
            index_manager: Object providing ``find_relevant_csvs(query)``
                and an ``indexes`` mapping of
                csv_id -> {"index": ..., "metadata": {"filename": ...}}.
            llm: Language model exposing ``complete(prompt)`` returning an
                object with a ``.text`` attribute.
            response_mode: llama_index response synthesis mode
                (e.g. "compact").
        """
        self.index_manager = index_manager
        self.llm = llm
        self.service_context = ServiceContext.from_defaults(llm=llm)
        self.response_mode = response_mode
        # Set up custom prompts
        self._setup_prompts()

    def _setup_prompts(self) -> None:
        """Set up the custom QA prompt used for every per-CSV query."""
        self.csv_query_prompt = PromptTemplate(
            """You are an AI assistant specialized in analyzing CSV data.
Answer the following query using the provided CSV information.
If calculations are needed, explain your process.
CSV Context: {context_str}
Query: {query_str}
Answer:"""
        )

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across all relevant CSV files.

        Args:
            query_text: The user's question.

        Returns:
            Dict with:
              - "answer": str — the (possibly combined) answer, or a
                fallback message when no CSV could serve the query.
              - "sources": list of {"csv": filename, "content": preview}
                dicts describing the retrieved source chunks.
        """
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)
        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": []
            }

        responses: List[Dict[str, Any]] = []
        sources: List[Dict[str, str]] = []

        for csv_id in relevant_csvs:
            index_info = self.index_manager.indexes.get(csv_id)
            if not index_info:
                # Stale id returned by the relevance search; skip it.
                continue

            filename = index_info["metadata"]["filename"]
            response = self._query_single_index(query_text, index_info["index"])
            responses.append({
                "csv_id": csv_id,
                "filename": filename,
                "response": response,
            })
            sources.extend(self._collect_sources(response, filename))

        if len(responses) > 1:
            return {
                "answer": self._combine_responses(query_text, responses),
                "sources": sources,
            }
        if responses:
            # BUGFIX: return a string, as the multi-file path does, rather
            # than the raw llama_index Response object under "answer".
            return {
                "answer": str(responses[0]["response"]),
                "sources": sources,
            }
        return {
            "answer": "Failed to process query with the available CSV data.",
            "sources": []
        }

    def _query_single_index(self, query_text: str, index) -> Any:
        """Run the query against a single CSV's vector index.

        Returns the raw llama_index response object.
        """
        retriever = VectorIndexRetriever(
            index=index,
            similarity_top_k=5
        )
        query_engine = RetrieverQueryEngine.from_args(
            retriever=retriever,
            service_context=self.service_context,
            text_qa_template=self.csv_query_prompt,
            response_mode=self.response_mode
        )
        return query_engine.query(query_text)

    @staticmethod
    def _collect_sources(response: Any, filename: str) -> List[Dict[str, str]]:
        """Extract truncated source previews from a query response.

        Responses without ``source_nodes`` yield an empty list.
        """
        sources: List[Dict[str, str]] = []
        if hasattr(response, "source_nodes"):
            for node in response.source_nodes:
                content = node.node.get_content()
                # BUGFIX: only append an ellipsis when the content was
                # actually truncated, instead of unconditionally.
                preview = content[:100] + "..." if len(content) > 100 else content
                sources.append({
                    "csv": filename,
                    "content": preview
                })
        return sources

    def _combine_responses(self, query_text: str, responses: List[Dict]) -> str:
        """Merge per-CSV answers into one unified answer via the LLM.

        Args:
            query_text: The original user question.
            responses: Per-file result dicts with "filename" and "response".

        Returns:
            The LLM-generated combined answer text.
        """
        # "\n".join replaces the original's obscure chr(10).join — identical output.
        findings = "\n".join(
            f"From {r['filename']}: {str(r['response'])}" for r in responses
        )
        combine_prompt = f"""
I need to answer this question: {query_text}
I've analyzed multiple CSV files and found these results:
{findings}
Please provide a unified answer that combines these insights.
"""
        # Use the LLM to generate a combined response
        combined_response = self.llm.complete(combine_prompt)
        return combined_response.text