|
import os |
|
import gradio as gr |
|
import tempfile |
|
from pathlib import Path |
|
import base64 |
|
from PIL import Image |
|
import io |
|
import time |
|
import sys |
|
|
|
|
|
# Directory containing this file; appended to sys.path so that imports are
# also searched for relative to it, regardless of the working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:  # avoid stacking duplicate entries on re-import
    sys.path.append(current_dir)
|
|
|
|
|
|
|
from models.llm_setup import setup_llm |
|
from indexes.csv_index_builder import EnhancedCSVReader |
|
from indexes.index_manager import CSVIndexManager |
|
from indexes.query_engine import CSVQueryEngine |
|
|
|
|
|
# The "tools" sub-directory next to this file is added to sys.path as well.
tools_dir = os.path.join(current_dir, "tools")
if tools_dir not in sys.path:  # avoid stacking duplicate entries on re-import
    sys.path.append(tools_dir)
|
|
|
from tools.data_tools import PandasDataTools |
|
from tools.visualization import VisualizationTools |
|
from tools.export import ExportTools |
|
|
|
|
|
# Scratch directories created once at import time.
# UPLOAD_DIR receives copies of the uploaded CSV files; EXPORT_DIR receives
# generated chart images and is the directory handed to the export tools.
# The prefixes only make the directories easy to spot in the system temp dir.
UPLOAD_DIR = Path(tempfile.mkdtemp(prefix="csv_chat_uploads_"))
EXPORT_DIR = Path(tempfile.mkdtemp(prefix="csv_chat_exports_"))
|
|
|
class CSVChatApp:
    """Main application class for the CSV chatbot.

    Holds the LLM, the CSV index manager, the data / visualization / export
    tool collections and the query engine, plus two pieces of session state:
    ``chat_history`` (list of ``{"role", "content"}`` dicts) and
    ``uploaded_files`` (paths of the CSV files that were indexed).
    """

    def __init__(self):
        """Initialize the application components."""
        self.llm = setup_llm()
        self.index_manager = CSVIndexManager()

        # Data and visualization tools work on the upload directory, export
        # tools on the export directory.
        self.data_tools = PandasDataTools(str(UPLOAD_DIR))
        self.viz_tools = VisualizationTools(str(UPLOAD_DIR))
        self.export_tools = ExportTools(str(EXPORT_DIR))

        self.query_engine = self._setup_query_engine()

        self.chat_history = []
        self.uploaded_files = []

    def _setup_query_engine(self):
        """Set up the query engine.

        Returns:
            A ``CSVQueryEngine`` built from the index manager and the LLM.
        """
        # NOTE(review): the combined tool list is assembled but never passed to
        # CSVQueryEngine. The calls are kept so behavior is unchanged; confirm
        # whether the engine is supposed to receive these tools.
        tools = (
            self.data_tools.get_tools()
            + self.viz_tools.get_tools()
            + self.export_tools.get_tools()
        )

        return CSVQueryEngine(self.index_manager, self.llm)

    def handle_file_upload(self, files):
        """Copy uploaded CSV files into ``UPLOAD_DIR`` and index them.

        Args:
            files: Uploads coming from the file component. Each item may be a
                path (``str`` or ``Path``) or an object exposing its path as
                ``.name``. ``None`` items and files whose suffix is not
                ``.csv`` (case-insensitive) are skipped. ``None`` instead of a
                list is treated as "nothing uploaded".

        Returns:
            One status line per CSV file, joined with newlines, or
            ``"No CSV files were uploaded."`` when no CSV file was processed.
        """
        file_info = []

        for file in files or []:
            if file is None:
                continue

            # With type="filepath" the component delivers plain path strings;
            # file-like wrappers carry the path in ``.name`` instead.
            source = file if isinstance(file, (str, Path)) else file.name
            file_path = Path(source)

            if file_path.suffix.lower() != ".csv":
                continue

            dest_path = UPLOAD_DIR / file_path.name

            try:
                # Copy first, then index the copy. A failure in either step is
                # reported for this file only and the loop carries on.
                dest_path.write_bytes(file_path.read_bytes())
                self.index_manager.create_index(str(dest_path))
            except Exception as e:
                # "\u274c" is the cross-mark emoji.
                file_info.append(f"\u274c Failed to index {file_path.name}: {str(e)}")
            else:
                # "\u2705" is the check-mark emoji.
                file_info.append(f"\u2705 Indexed: {file_path.name}")
                self.uploaded_files.append(str(dest_path))

        if file_info:
            return "\n".join(file_info)
        return "No CSV files were uploaded."

    def _save_visualization(self, img_data):
        """Decode a base64-encoded image and save it as a PNG in ``EXPORT_DIR``.

        Args:
            img_data: Base64 text (or bytes) of an image PIL can open.

        Returns:
            The path of the written PNG file, as a string.
        """
        # Millisecond resolution so two charts produced within the same second
        # do not overwrite each other.
        img_path = EXPORT_DIR / f"viz_{int(time.time() * 1000)}.png"
        with Image.open(io.BytesIO(base64.b64decode(img_data))) as img:
            img.save(img_path)
        return str(img_path)

    def process_query(self, query, history):
        """Process a user query and return the updated chatbot history.

        Args:
            query: The question typed by the user.
            history: Current chatbot value, a list of ``[user, bot]`` pairs.
                ``None`` is treated as an empty history.

        Returns:
            A new list: ``history`` plus the row(s) for this exchange. A text
            answer adds one ``[query, answer]`` row. An answer dict containing
            an ``"image"`` key adds a text row followed by an image row
            ``[None, (image_path, alt_text)]``. Errors raised while answering
            are reported as the bot message instead of propagating.
        """
        history = history or []

        if not self.uploaded_files:
            return history + [[query, "Please upload CSV files before asking questions."]]

        self.chat_history.append({"role": "user", "content": query})

        try:
            response = self.query_engine.query(query)
            answer = response["answer"]

            if isinstance(answer, dict) and "image" in answer:
                img_path = self._save_visualization(answer["image"])
                text_response = answer.get("text", "Generated visualization")

                # Internal history keeps the (text, image path) pair that
                # export_conversation unpacks.
                self.chat_history.append(
                    {"role": "assistant", "content": (text_response, img_path)}
                )

                # The chatbot treats a tuple message as (filepath, alt text),
                # so the text goes in its own row and the image row puts the
                # file path first.
                return history + [
                    [query, text_response],
                    [None, (img_path, text_response)],
                ]

            self.chat_history.append({"role": "assistant", "content": answer})
            return history + [[query, answer]]

        except Exception as e:
            # UI boundary: show the failure in the conversation rather than
            # letting the callback crash.
            error_msg = f"Error processing query: {str(e)}"
            self.chat_history.append({"role": "assistant", "content": error_msg})
            return history + [[query, error_msg]]

    def export_conversation(self):
        """Export the conversation as a report.

        Builds one ``"<speaker>: <text>"`` paragraph per history entry, collects
        the base64-encoded bytes of every image referenced by a
        ``(text, image_path)`` entry, and passes both to
        ``self.export_tools.generate_report``.

        Returns:
            A status string: the report location on success, a failure notice
            otherwise, or ``"No conversation to export."`` for an empty history.
        """
        if not self.chat_history:
            return "No conversation to export."

        title = "CSV Chat Conversation Report"
        paragraphs = []
        images = []

        for msg in self.chat_history:
            speaker = "User" if msg["role"] == "user" else "Assistant"
            body = msg["content"]

            if isinstance(body, tuple) and len(body) == 2:
                body, img_path = body
                try:
                    with open(img_path, "rb") as img_file:
                        images.append(base64.b64encode(img_file.read()).decode("utf-8"))
                except OSError:
                    # Best effort: an image file that cannot be read is left
                    # out of the report, the text is still exported.
                    pass

            paragraphs.append(f"\n\n{speaker}: {body}")

        content = "".join(paragraphs)
        result = self.export_tools.generate_report(title, content, images)

        if result["success"]:
            return f"Report exported to: {result['report_path']}"
        return "Failed to export report."
|
|
|
|
|
def create_interface():
    """Create the Gradio web interface.

    Instantiates one ``CSVChatApp`` and builds a two-column layout: the left
    column holds the CSV upload and conversation export controls, the right
    column holds the chatbot, the question box and the submit button. The three
    buttons are bound to the corresponding ``CSVChatApp`` methods.

    Returns:
        The assembled ``gr.Blocks`` interface (not yet launched).
    """
    app = CSVChatApp()

    with gr.Blocks(title="CSV Chat Assistant") as interface:
        gr.Markdown("# CSV Chat Assistant")
        gr.Markdown("Upload CSV files and ask questions in natural language.")

        with gr.Row():
            # Left column: file handling and export.
            with gr.Column(scale=1):
                file_upload = gr.File(
                    label="Upload CSV Files",
                    file_count="multiple",
                    type="filepath",
                )
                upload_button = gr.Button("Process Files")
                file_status = gr.Textbox(label="File Status")

                export_button = gr.Button("Export Conversation")
                export_status = gr.Textbox(label="Export Status")

            # Right column: the conversation itself.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(label="Conversation")
                msg = gr.Textbox(label="Your Question")
                submit_button = gr.Button("Submit")

        # Event wiring: each button calls one method of the shared app object.
        upload_button.click(
            fn=app.handle_file_upload,
            inputs=[file_upload],
            outputs=[file_status],
        )

        submit_button.click(
            fn=app.process_query,
            inputs=[msg, chatbot],
            outputs=[chatbot],
        )

        export_button.click(
            fn=app.export_conversation,
            inputs=[],
            outputs=[export_status],
        )

    return interface
|
|
|
|
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server only when run as a script.
    interface = create_interface()
    interface.launch()
|
|