|
import atexit
import base64
import html
import io
import json
import os
import sqlite3
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from dotenv import load_dotenv
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
from sqlalchemy import create_engine
|
|
|
# Populate os.environ from a local .env file before reading configuration.
load_dotenv()

# OpenAI key handed to the pandasai LLM wrapper; None when the variable is unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

# Most recently constructed DataChatApp; rebound by DataChatApp.__init__.
app_instance = None
|
|
|
class DataChatApp:
    """Hold the active dataset and implement the Gradio callbacks for it.

    Instance attributes:
        df: the loaded ``pandas.DataFrame``, or ``None`` until data is loaded.
        data_source: short description of where ``df`` came from.
        llm: pandasai ``OpenAI`` wrapper built from ``OPENAI_API_KEY``.
        smart_df: pandasai ``SmartDataframe`` wrapping ``df`` for LLM queries.
        chat_history: kept for compatibility; not used by these methods.
        temp_files: paths of temporary files that :meth:`cleanup` deletes.
        db_connection: the sqlite3 connection or SQLAlchemy engine opened by
            :meth:`connect_database`, or ``None``.
    """

    # Upload extensions accepted by load_file, mapped to their pandas reader.
    _READERS = {
        ".csv": pd.read_csv,
        ".xlsx": pd.read_excel,
        ".xls": pd.read_excel,
        ".json": pd.read_json,
    }

    # Chart types create_visualization can draw, and the subset for which the
    # Y column is optional.
    _VIZ_TYPES = ("bar", "line", "scatter", "pie", "histogram")
    _Y_OPTIONAL_VIZ = ("pie", "histogram")

    def __init__(self):
        self.df = None
        self.data_source = None
        self.llm = OpenAI(api_token=OPENAI_API_KEY)
        self.smart_df = None
        self.chat_history = []
        self.temp_files = []
        self.db_connection = None

        # Publish this instance through the module-level ``app_instance``.
        global app_instance
        app_instance = self

    # ------------------------------------------------------------------ #
    # Loading data
    # ------------------------------------------------------------------ #

    def _set_dataframe(self, df, source):
        """Make *df* the active dataset; return ``(preview_html, info_json)``.

        The SmartDataframe is built before any attribute is assigned, so if
        that step raises, the previously loaded data stays fully active
        (``df`` and ``smart_df`` never disagree).
        """
        smart_df = SmartDataframe(df, config={"llm": self.llm})
        self.df = df
        self.smart_df = smart_df
        self.data_source = source
        return df.head().to_html(), self._get_dataframe_info()

    def load_file(self, file):
        """Load an uploaded CSV, Excel or JSON file as the active dataset.

        Args:
            file: value of a ``gr.File`` component: either a filepath string
                (Gradio 4+ default) or an object whose ``.name`` is the path.

        Returns:
            ``(status_message, preview_html, info_json)``; the last two are
            ``None`` when nothing was loaded.
        """
        if file is None:
            return "No file uploaded", None, None

        # Accept both plain paths and tempfile-like wrappers exposing ``.name``.
        file_path = file if isinstance(file, (str, os.PathLike)) else file.name
        file_name = os.path.basename(file_path)
        file_ext = os.path.splitext(file_name)[1].lower()

        reader = self._READERS.get(file_ext)
        if reader is None:
            return f"Unsupported file format: {file_ext}", None, None

        try:
            preview, info = self._set_dataframe(reader(file_path), f"File: {file_name}")
        except Exception as e:
            return f"Error loading file: {str(e)}", None, None
        return f"Loaded successfully: {file_name}", preview, info

    @staticmethod
    def _open_connection(connection_string):
        """Open a sqlite3 connection for ``sqlite:`` URLs, else a SQLAlchemy engine."""
        if not connection_string.startswith("sqlite:"):
            return create_engine(connection_string)

        prefix = "sqlite:///"
        db_path = connection_string[len(prefix):] if connection_string.startswith(prefix) else ""
        # "sqlite://" and "sqlite:///:memory:" denote an in-memory database.
        # Match them exactly; a substring test for "memory" would also catch
        # ordinary files such as "memory_log.db".
        if db_path in ("", ":memory:"):
            db_path = ":memory:"
        # Gradio runs callbacks on worker threads while cleanup() may run on a
        # different thread at exit; allow closing the connection from there.
        return sqlite3.connect(db_path, check_same_thread=False)

    @staticmethod
    def _close_connection(connection):
        """Close a sqlite3 connection or dispose of a SQLAlchemy engine."""
        if hasattr(connection, "close"):
            connection.close()
        elif hasattr(connection, "dispose"):
            connection.dispose()

    def connect_database(self, connection_string, query):
        """Run *query* against the database and make the result the active dataset.

        Args:
            connection_string: ``sqlite:///path`` (opened with sqlite3) or any
                SQLAlchemy URL.
            query: SQL whose result set becomes the dataset.

        Returns:
            ``(status_message, preview_html, info_json)``; the last two are
            ``None`` on failure. On failure the previous data and connection
            are left untouched.
        """
        connection_string = (connection_string or "").strip()
        query = (query or "").strip()
        if not connection_string:
            return "Please provide a connection string", None, None
        # Validate before opening anything so no connection is leaked.
        if not query:
            return "Please provide a SQL query", None, None

        connection = None
        try:
            connection = self._open_connection(connection_string)
            df = pd.read_sql(query, connection)
            # Only the scheme is shown, so credentials in the URL never reach the UI.
            source = f"Database: {connection_string.split('://')[0]}"
            preview, info = self._set_dataframe(df, source)
        except Exception as e:
            if connection is not None:
                try:
                    self._close_connection(connection)
                except Exception:
                    pass  # best effort: the original error is what gets reported
            return f"Database connection error: {str(e)}", None, None

        # Swap in the new connection only after success and release the old one.
        previous, self.db_connection = self.db_connection, connection
        if previous is not None:
            try:
                self._close_connection(previous)
            except Exception:
                pass  # best effort: the new data is already active
        return "Database connected successfully", preview, info

    def _get_dataframe_info(self):
        """Return a JSON string describing ``self.df``, or ``None`` if no data."""
        if self.df is None:
            return None
        # Column labels are stringified: JSON object keys must be strings, and
        # labels read from files can be ints, timestamps or tuples.
        info = {
            "Shape": self.df.shape,
            "Columns": [str(col) for col in self.df.columns],
            "Data Types": {str(col): str(dtype) for col, dtype in self.df.dtypes.items()},
            "Missing Values": {
                str(col): int(count) for col, count in self.df.isnull().sum().items()
            },
        }
        return json.dumps(info, indent=2)

    # ------------------------------------------------------------------ #
    # Rendering helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _figure_to_base64(fig):
        """Render a matplotlib figure to PNG and return it base64-encoded."""
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png", dpi=100, bbox_inches="tight")
        return base64.b64encode(buffer.getvalue()).decode("ascii")

    @staticmethod
    def _img_tag(img_data):
        """Wrap base64 PNG data in an inline <img> tag (no file serving needed)."""
        return (
            f"<img src='data:image/png;base64,{img_data}' "
            "alt='Visualization' style='max-width: 100%;' />"
        )

    def _render_chat_response(self, response):
        """Convert a pandasai answer into text/HTML for the chatbot."""
        if isinstance(response, plt.Figure):
            try:
                img_data = self._figure_to_base64(response)
            finally:
                plt.close(response)  # release the figure held by pyplot
            return self._img_tag(img_data)
        if isinstance(response, pd.DataFrame):
            return f"<div style='overflow-x: auto;'>{response.to_html(index=False)}</div>"
        # pandasai may answer with the path of a chart it saved to disk;
        # embed the image rather than echoing the path.
        if (
            isinstance(response, str)
            and response.lower().endswith(".png")
            and os.path.isfile(response)
        ):
            with open(response, "rb") as img_file:
                return self._img_tag(base64.b64encode(img_file.read()).decode("ascii"))
        return str(response)

    # ------------------------------------------------------------------ #
    # Gradio callbacks
    # ------------------------------------------------------------------ #

    def chat_with_data(self, query, history):
        """Answer a natural-language *query* about the active dataset.

        Returns:
            ``(new_textbox_value, chat_history)``. ``chat_history`` is a list
            of ``{"role", "content"}`` messages. Validation failures return
            the message in the textbox slot and leave the history unchanged.
        """
        if self.df is None or self.smart_df is None:
            return "Please load data first before querying.", history
        if not query:
            return "Please enter a query.", history

        # Copy instead of mutating the list Gradio handed us.
        history = list(history) if history else []
        try:
            response_text = self._render_chat_response(self.smart_df.chat(query))
        except Exception as e:
            response_text = f"Error processing query: {str(e)}"

        history.append({"role": "user", "content": query})
        history.append({"role": "assistant", "content": response_text})
        return "", history

    def create_visualization(self, viz_type, x_axis, y_axis, title):
        """Draw a chart of the active dataset and return it as embeddable HTML.

        Args:
            viz_type: one of ``bar``, ``line``, ``scatter``, ``pie``, ``histogram``.
            x_axis: X column (category column for pie, value column for histogram).
            y_axis: value column; optional for pie (row counts are used
                instead) and ignored for histogram.
            title: chart title; derived from the columns when empty.

        Returns:
            An HTML snippet with the PNG embedded as a base64 data URI, or a
            plain message when the chart cannot be drawn.
        """
        if self.df is None:
            return "Please load data first before creating visualizations."
        needs_y = viz_type not in self._Y_OPTIONAL_VIZ
        if not x_axis or (needs_y and not y_axis):
            return "Please select both X and Y axis for the visualization."
        if viz_type not in self._VIZ_TYPES:
            return f"Unsupported visualization type: {html.escape(str(viz_type))}"
        if x_axis not in self.df.columns:
            return f"Column '{html.escape(str(x_axis))}' not found in the data."
        if needs_y and y_axis not in self.df.columns:
            return f"Column '{html.escape(str(y_axis))}' not found in the data."

        try:
            # Draw on a standalone Figure instead of pyplot's global "current
            # figure": Gradio callbacks run on worker threads and pyplot state
            # is shared. A bare Figure is not registered with pyplot, so it
            # needs no plt.close().
            fig = plt.Figure(figsize=(10, 6))
            ax = fig.add_subplot(111)
            x_values = self.df[x_axis]

            if viz_type == "bar":
                ax.bar(x_values, self.df[y_axis])
                ax.set_xlabel(x_axis)
                ax.set_ylabel(y_axis)
                ax.set_title(title or f"Bar Chart: {y_axis} by {x_axis}")
            elif viz_type == "line":
                ax.plot(x_values, self.df[y_axis])
                ax.set_xlabel(x_axis)
                ax.set_ylabel(y_axis)
                ax.set_title(title or f"Line Chart: {y_axis} over {x_axis}")
            elif viz_type == "scatter":
                ax.scatter(x_values, self.df[y_axis])
                ax.set_xlabel(x_axis)
                ax.set_ylabel(y_axis)
                ax.set_title(title or f"Scatter Plot: {y_axis} vs {x_axis}")
            elif viz_type == "pie":
                if y_axis and y_axis in self.df.columns:
                    slices = self.df.groupby(x_axis)[y_axis].sum()
                else:
                    slices = x_values.value_counts()
                ax.pie(slices, labels=slices.index, autopct="%1.1f%%")
                ax.set_title(title or f"Pie Chart: Distribution of {x_axis}")
            else:  # histogram
                # NaNs make the automatic bin range non-finite; drop them.
                ax.hist(x_values.dropna(), bins=20)
                ax.set_xlabel(x_axis)
                ax.set_ylabel("Frequency")
                ax.set_title(title or f"Histogram: Distribution of {x_axis}")

            fig.tight_layout()
            img_data = self._figure_to_base64(fig)
        except Exception as e:
            return f"Error creating visualization: {html.escape(str(e))}"

        return f"""
<div style="text-align: center; padding: 20px; background-color: white; border-radius: 10px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);">
<img src="data:image/png;base64,{img_data}" style="max-width: 100%; height: auto;" alt="Visualization">
</div>
"""

    def generate_summary_cards(self):
        """Return HTML cards with mean/median/min/max for each numeric column."""
        if self.df is None:
            return "Please load data first before generating summary cards."

        try:
            num_cols = self.df.select_dtypes(include=[np.number]).columns.tolist()
            if not num_cols:
                return "No numerical columns found for summary cards."

            header = """
<style>
.summary-card {
background-color: #f5f5f5;
border-radius: 5px;
padding: 15px;
min-width: 200px;
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
margin: 10px;
}
.summary-card h3 {
margin-top: 0;
color: #333 !important;
font-weight: bold;
}
.summary-card p {
color: #333 !important;
margin: 8px 0;
}
.summary-card strong {
font-weight: bold;
color: #333 !important;
}
.summary-container {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
</style>
<div class="summary-container">
"""
            cards = []
            for col in num_cols:
                series = self.df[col]
                # Column names come from uploaded data: escape before embedding.
                cards.append(
                    '<div class="summary-card">'
                    f"<h3>{html.escape(str(col))}</h3>"
                    f"<p><strong>Mean:</strong> {series.mean():.2f}</p>"
                    f"<p><strong>Median:</strong> {series.median():.2f}</p>"
                    f"<p><strong>Min:</strong> {series.min():.2f}</p>"
                    f"<p><strong>Max:</strong> {series.max():.2f}</p>"
                    "</div>\n"
                )
            return header + "".join(cards) + "</div>"
        except Exception as e:
            return f"Error generating summary cards: {html.escape(str(e))}"

    def cleanup(self):
        """Delete tracked temp files and release the DB connection.

        Best effort and safe to call more than once: the file list is
        emptied and ``db_connection`` is reset to ``None`` afterwards.
        """
        for path in self.temp_files:
            try:
                os.unlink(path)
            except OSError:
                pass  # already gone or not removable; nothing more to do here
        self.temp_files.clear()

        if self.db_connection is not None:
            try:
                self._close_connection(self.db_connection)
            except Exception:
                pass  # best effort during shutdown
            self.db_connection = None
|
|
|
def create_interface():
    """Build and return the Gradio Blocks UI (the caller launches it).

    A new DataChatApp backs the UI. Its cleanup() is registered with atexit
    so the temp files and DB connection it holds are released at exit.
    """
    app = DataChatApp()
    atexit.register(app.cleanup)

    def update_column_options():
        """Return fresh choices for the X and Y dropdowns from ``app.df``."""
        # Read the closed-over ``app`` rather than the module-level
        # ``app_instance`` global, which is rebound whenever any DataChatApp
        # is constructed and could then point at another instance's data.
        columns = list(app.df.columns) if app.df is not None else []
        # value=None clears a selection that may not exist in the new data.
        return (
            gr.update(choices=columns, value=None),
            gr.update(choices=columns, value=None),
        )

    with gr.Blocks(theme=gr.themes.Soft(), title="Data Chat App", css="""
    .plot-container {width: 100% !important; height: 100% !important;}
    .js-plotly-plot {min-height: 500px;}
    .plotly {min-height: 500px;}
    """) as interface:
        gr.Markdown("""
        # GIN Data Chat Application

        Upload your data file or connect to a database, then chat with your data using natural language!
        """)

        with gr.Tabs():
            with gr.TabItem("Load Data"):
                with gr.Tab("File Upload"):
                    file_input = gr.File(
                        label="Upload CSV, Excel, or JSON file",
                        file_types=[".csv", ".xlsx", ".xls", ".json"],
                    )
                    file_upload_button = gr.Button("Load File")
                    file_result = gr.Textbox(label="Result")

                with gr.Tab("Database Connection"):
                    conn_str = gr.Textbox(
                        label="Connection String",
                        placeholder="E.g., sqlite:///data.db, postgresql://user:pass@localhost/db"
                    )
                    query = gr.Textbox(
                        label="SQL Query",
                        placeholder="SELECT * FROM your_table LIMIT 1000"
                    )
                    db_connect_button = gr.Button("Connect to Database")
                    db_result = gr.Textbox(label="Result")

                # Shared by both loading paths above.
                preview = gr.HTML(label="Data Preview")
                info = gr.JSON(label="Data Information")

            with gr.TabItem("Chat with Data"):
                chat_interface = gr.Chatbot(height=400, type="messages")
                query_input = gr.Textbox(
                    label="Ask a question about your data",
                    placeholder="E.g., Show me the trend of sales over time",
                    lines=2
                )
                chat_button = gr.Button("Ask")

            with gr.TabItem("Visualize Data"):
                with gr.Row():
                    with gr.Column(scale=1):
                        viz_type = gr.Dropdown(
                            choices=["bar", "line", "scatter", "pie", "histogram"],
                            label="Visualization Type",
                            value="bar"
                        )
                        x_axis = gr.Dropdown(label="X-Axis / Category")
                        y_axis = gr.Dropdown(label="Y-Axis / Values (Optional for Pie & Histogram)")
                        viz_title = gr.Textbox(label="Chart Title (Optional)")
                        viz_button = gr.Button("Generate Visualization", variant="primary")

                    with gr.Column(scale=2):
                        viz_output = gr.HTML(label="Visualization", value="<div style='width:100%; height:500px; display:flex; justify-content:center; align-items:center; color:#666; font-size:16px;'>Your visualization will appear here</div>")

            with gr.TabItem("Summary Stats"):
                summary_button = gr.Button("Generate Summary Cards")
                summary_output = gr.HTML(label="Summary Statistics")

        # After either loader finishes, refresh both axis dropdowns at once.
        file_upload_button.click(
            app.load_file,
            inputs=[file_input],
            outputs=[file_result, preview, info],
        ).then(
            update_column_options,
            inputs=None,
            outputs=[x_axis, y_axis],
        )

        db_connect_button.click(
            app.connect_database,
            inputs=[conn_str, query],
            outputs=[db_result, preview, info],
        ).then(
            update_column_options,
            inputs=None,
            outputs=[x_axis, y_axis],
        )

        # The button and pressing Enter in the textbox both submit a question.
        chat_button.click(
            app.chat_with_data,
            inputs=[query_input, chat_interface],
            outputs=[query_input, chat_interface],
        )
        query_input.submit(
            app.chat_with_data,
            inputs=[query_input, chat_interface],
            outputs=[query_input, chat_interface],
        )

        viz_button.click(
            app.create_visualization,
            inputs=[viz_type, x_axis, y_axis, viz_title],
            outputs=[viz_output],
        )

        summary_button.click(
            app.generate_summary_cards,
            outputs=[summary_output],
        )

    return interface
|
|
|
if __name__ == "__main__":
    interface = create_interface()
    # create_interface() builds the DataChatApp that actually serves the UI
    # and DataChatApp.__init__ publishes it as ``app_instance``; register
    # cleanup for that instance rather than constructing a separate, unused one.
    if app_instance is not None:
        atexit.register(app_instance.cleanup)
    # NOTE(review): share=True publishes the app on a public Gradio link.
    # Anyone with the link can point the DB tab at sqlite files / database URLs
    # reachable from this machine and send queries to pandasai (which may run
    # LLM-generated code). Confirm this exposure is intended.
    interface.launch(share=True)