from typing import Dict, List, Optional, Any

from llama_index import ServiceContext, QueryBundle
from llama_index.llms import HuggingFaceLLM
from llama_index.prompts import PromptTemplate
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import ResponseMode
from llama_index.retrievers import VectorIndexRetriever


class CSVQueryEngine:
    """Query engine for CSV data with multi-file support."""

    def __init__(self, index_manager, llm, response_mode="compact"):
        """Initialize with an index manager and a language model."""
        self.index_manager = index_manager
        self.llm = llm
        self.service_context = ServiceContext.from_defaults(llm=llm)
        self.response_mode = response_mode

        self._setup_prompts()

    def _setup_prompts(self):
        """Set up custom prompts for CSV querying."""
        self.csv_query_prompt = PromptTemplate(
            """You are an AI assistant specialized in analyzing CSV data.
Answer the following query using the provided CSV information.
If calculations are needed, explain your process.

CSV Context: {context_str}
Query: {query_str}

Answer:"""
        )

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across CSV files.

        Returns a dict with an "answer" string and a list of "sources"
        describing the CSV snippets that informed it.
        """
        # Ask the index manager which CSV files are relevant to this query.
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)

        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": []
            }

        responses = []
        sources = []

        # Query each relevant CSV index independently and collect the results.
        for csv_id in relevant_csvs:
            index_info = self.index_manager.indexes.get(csv_id)
            if not index_info:
                continue

            index = index_info["index"]
            metadata = index_info["metadata"]

            # Retrieve the 5 most similar chunks from this CSV's vector index.
            retriever = VectorIndexRetriever(
                index=index,
                similarity_top_k=5
            )

            query_engine = RetrieverQueryEngine.from_args(
                retriever=retriever,
                service_context=self.service_context,
                text_qa_template=self.csv_query_prompt,
                response_mode=self.response_mode
            )

            response = query_engine.query(query_text)

            responses.append({
                "csv_id": csv_id,
                "filename": metadata["filename"],
                "response": response
            })

            # Keep a short snippet of each source node for citation purposes.
            if hasattr(response, "source_nodes"):
                for node in response.source_nodes:
                    sources.append({
                        "csv": metadata["filename"],
                        "content": node.node.get_content()[:100] + "..."
                    })

        # Merge per-file answers when more than one CSV contributed.
        if len(responses) > 1:
            combined_response = self._combine_responses(query_text, responses)
            return {
                "answer": combined_response,
                "sources": sources
            }
        elif len(responses) == 1:
            # Convert the Response object to a string so the return type
            # matches the other branches.
            return {
                "answer": str(responses[0]["response"]),
                "sources": sources
            }
        else:
            return {
                "answer": "Failed to process query with the available CSV data.",
                "sources": []
            }

    def _combine_responses(self, query_text: str, responses: List[Dict]) -> str:
        """Combine responses from multiple CSV files into a single answer."""
        # Build a newline-separated summary of each file's answer, then ask the
        # LLM to synthesize them into one unified response.
        combine_prompt = f"""
I need to answer this question: {query_text}

I've analyzed multiple CSV files and found these results:

{chr(10).join([f"From {r['filename']}: {str(r['response'])}" for r in responses])}

Please provide a unified answer that combines these insights.
"""

        combined_response = self.llm.complete(combine_prompt)

        return combined_response.text
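
# --- Usage sketch (illustrative, not part of the engine) ---------------------
# The snippet below shows one way this class might be wired up. The index
# manager import and the CSV paths are assumptions made for the example:
# CSVQueryEngine only requires an object exposing `find_relevant_csvs(query)`
# and an `indexes` dict mapping csv_id -> {"index": <vector index>,
# "metadata": {"filename": ...}}, plus any llama_index-compatible LLM
# (HuggingFaceLLM is shown because it is imported above).
#
#   from csv_index_manager import CSVIndexManager  # hypothetical module
#
#   llm = HuggingFaceLLM(
#       model_name="HuggingFaceH4/zephyr-7b-beta",      # example model choice
#       tokenizer_name="HuggingFaceH4/zephyr-7b-beta",
#   )
#   index_manager = CSVIndexManager()
#   index_manager.add_csv("data/sales.csv")             # hypothetical helper
#   index_manager.add_csv("data/inventory.csv")
#
#   engine = CSVQueryEngine(index_manager, llm)
#   result = engine.query("Which product had the highest total sales?")
#   print(result["answer"])
#   for source in result["sources"]:
#       print(f"- {source['csv']}: {source['content']}")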