|
from typing import Dict, List, Any, Optional |
|
import pandas as pd |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import numpy as np |
|
|
|
class CSVQueryEngine:
    """Query engine for CSV data with multi-file support.

    ``index_manager`` must provide ``find_relevant_csvs(query)`` returning a
    list of CSV ids, and an ``indexes`` mapping from id to a dict holding
    ``"metadata"`` (with ``filename``, ``columns`` and ``row_count``) and
    ``"path"`` (anything ``pandas.read_csv`` accepts). ``llm`` must provide
    ``complete(prompt)`` returning an object with a ``.text`` attribute.
    """

    # Whole query words that request each aggregate. Matching is done on
    # words, not substrings, so "administration" does not request a minimum
    # and "meaningful" does not request an average.
    _AVG_KEYWORDS = frozenset({"average", "averages", "avg", "mean"})
    _MAX_KEYWORDS = frozenset({"max", "maximum", "maximums", "highest", "largest"})
    _MIN_KEYWORDS = frozenset({"min", "minimum", "minimums", "lowest", "smallest"})

    # (direct-answer label, trigger words, pandas Series method, context line).
    # Order matters: it fixes the order in which stats appear per column.
    _STAT_SPECS = (
        ("average", _AVG_KEYWORDS, "mean", "\nThe average {col} is: {value:.2f}"),
        ("max", _MAX_KEYWORDS, "max", "\nThe maximum {col} is: {value}"),
        ("min", _MIN_KEYWORDS, "min", "\nThe minimum {col} is: {value}"),
    )

    # Filler words that must not, on their own, tie a query to a column.
    _STOPWORDS = frozenset({
        "a", "an", "and", "are", "as", "at", "be", "by", "can", "do", "does",
        "for", "from", "give", "how", "i", "in", "is", "it", "its", "me", "my",
        "of", "on", "or", "show", "tell", "than", "that", "the", "their",
        "there", "this", "to", "was", "we", "were", "what", "which", "who",
        "with", "you", "your",
    })

    # Words ignored when matching columns: filler plus the stat keywords
    # themselves (so "max age" does not also select a "max_speed" column).
    _NON_COLUMN_WORDS = _STOPWORDS | _AVG_KEYWORDS | _MAX_KEYWORDS | _MIN_KEYWORDS

    def __init__(self, index_manager, llm):
        """Initialize with index manager and language model."""
        self.index_manager = index_manager
        self.llm = llm

    def query(self, query_text: str) -> Dict[str, Any]:
        """Process a natural language query across CSV files.

        Returns a dict with ``"answer"`` (the LLM's text, or a fixed message
        when no CSV is relevant) and ``"sources"`` (see ``_get_sources``).
        """
        relevant_csvs = self.index_manager.find_relevant_csvs(query_text)
        if not relevant_csvs:
            return {
                "answer": "No relevant CSV files found for your query.",
                "sources": [],
            }

        context = self._prepare_context(query_text, relevant_csvs)
        prompt = self._generate_prompt(query_text, context)
        response = self.llm.complete(prompt)

        return {
            "answer": response.text,
            "sources": self._get_sources(relevant_csvs),
        }

    @staticmethod
    def _tokenize(text: Any) -> List[str]:
        """Split *text* into lowercase alphanumeric words.

        Every character that is not a letter or digit (space, punctuation,
        underscore, hyphen) acts as a separator, so ``"unit_price?"`` yields
        ``["unit", "price"]``. Non-string input is converted with ``str``.
        """
        cleaned = "".join(ch if ch.isalnum() else " " for ch in str(text).lower())
        return cleaned.split()

    @staticmethod
    def _file_header(metadata: Dict[str, Any]) -> List[str]:
        """Return the file-name / columns / row-count lines for one CSV."""
        return [
            f"CSV File: {metadata['filename']}",
            f"Columns: {', '.join(str(c) for c in metadata['columns'])}",
            f"Row Count: {metadata['row_count']}",
        ]

    def _requested_stats(self, query_words: set) -> List[tuple]:
        """Return the ``_STAT_SPECS`` entries whose trigger words occur in the query."""
        return [spec for spec in self._STAT_SPECS if query_words & spec[1]]

    def _match_columns(self, df: pd.DataFrame, query_words: set) -> List[Any]:
        """Return the columns of *df* that the query appears to refer to.

        A column matches when one of its name's words (see ``_tokenize``)
        equals a meaningful query word or that word's plain plural
        ("price" matches "prices"). When no matched column is numeric, all
        numeric columns are returned instead so aggregates can still be
        offered.
        """
        candidates = {w for w in query_words if w not in self._NON_COLUMN_WORDS}

        def mentioned(token: str) -> bool:
            return (
                token in candidates
                or f"{token}s" in candidates
                or f"{token}es" in candidates
            )

        matches = [
            col for col in df.columns
            if any(mentioned(tok) for tok in self._tokenize(col))
        ]
        if not any(pd.api.types.is_numeric_dtype(df[col]) for col in matches):
            matches = df.select_dtypes(include=["number"]).columns.tolist()
        return matches

    def _prepare_context(self, query: str, csv_ids: List[str]) -> str:
        """Prepare context from relevant CSV files with pre-calculated statistics.

        For each known CSV id, this adds the file header and the first 3 rows.
        When the query asks for an average/maximum/minimum, it also adds that
        statistic for every matched numeric column. Ids missing from
        ``index_manager.indexes`` are skipped. A file that fails to load or
        aggregate contributes an ``Error reading CSV: ...`` line instead.
        All computed values are repeated under a final ``Direct Answer:``
        section, each tagged with its source file so same-named columns from
        different files do not collide.
        """
        query_words = set(self._tokenize(query))
        wanted = self._requested_stats(query_words)

        context_parts: List[str] = []
        direct_answers: List[str] = []

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            entry = self.index_manager.indexes[csv_id]
            metadata = entry["metadata"]
            context_parts.extend(self._file_header(metadata))

            # Best-effort per file: a bad file is reported in the context
            # rather than aborting the whole query.
            try:
                df = pd.read_csv(entry["path"])
                context_parts.append("\nSample Data:")
                context_parts.append(df.head(3).to_string())

                if not wanted:
                    continue

                for col in self._match_columns(df, query_words):
                    series = df[col]
                    if not pd.api.types.is_numeric_dtype(series):
                        continue
                    for label, _, method, template in wanted:
                        value = getattr(series, method)()
                        context_parts.append(template.format(col=col, value=value))
                        direct_answers.append(
                            f"{label} {col} ({metadata['filename']}): {value}"
                        )
            except Exception as e:
                context_parts.append(f"Error reading CSV: {str(e)}")

        if direct_answers:
            context_parts.append("\nDirect Answer:")
            context_parts.extend(direct_answers)

        return "\n\n".join(context_parts)

    def _prepare_context1(self, query: str, csv_ids: List[str]) -> str:
        """Prepare context from relevant CSV files.

        Alternative to ``_prepare_context`` (not called by ``query``). It
        includes the first 5 rows and mean/min/max for every numeric column,
        regardless of the query text. ``query`` is accepted for signature
        parity but unused.
        """
        context_parts: List[str] = []

        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            entry = self.index_manager.indexes[csv_id]
            context_parts.extend(self._file_header(entry["metadata"]))

            try:
                df = pd.read_csv(entry["path"])
                context_parts.append("\nSample Data:")
                context_parts.append(df.head(5).to_string())

                context_parts.append("\nNumeric Column Statistics:")
                for col in df.select_dtypes(include=["number"]).columns:
                    stats = df[col].describe()
                    context_parts.append(
                        f"{col} - mean: {stats['mean']:.2f}, "
                        f"min: {stats['min']:.2f}, max: {stats['max']:.2f}"
                    )
            except Exception as e:
                context_parts.append(f"Error reading CSV: {str(e)}")

        return "\n\n".join(context_parts)

    def _generate_prompt(self, query: str, context: str) -> str:
        """Generate a prompt for the LLM embedding *context* and *query*."""
        return f"""You are an AI assistant specialized in analyzing CSV data.
Your goal is to help users understand their data and extract insights.

Below is information about CSV files that might help answer the query:

{context}

User Query: {query}

Please provide a comprehensive and accurate answer based on the data.
If calculations are needed, explain your process.
If the data doesn't contain information to answer the query, say so clearly.

Answer:"""

    def _get_sources(self, csv_ids: List[str]) -> List[Dict[str, str]]:
        """Get source information for the response.

        Returns one ``{"csv": filename, "columns": ...}`` dict per known id.
        ``columns`` lists at most the first 5 column names, with ``...``
        appended when there are more. Unknown ids are skipped.
        """
        sources = []
        for csv_id in csv_ids:
            if csv_id not in self.index_manager.indexes:
                continue

            metadata = self.index_manager.indexes[csv_id]["metadata"]
            columns = [str(c) for c in metadata["columns"]]
            suffix = "..." if len(columns) > 5 else ""
            sources.append({
                "csv": metadata["filename"],
                "columns": ", ".join(columns[:5]) + suffix,
            })

        return sources
|
|