# pavanmutha's picture
# Update app.py
# 172385c verified
# Initialization and Imports
import os
import re
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import shap
import lime.lime_tabular
import optuna
import wandb
import json
import time
import psutil
import shutil
import ast
from smolagents import HfApiModel, CodeAgent
from huggingface_hub import login
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import LabelEncoder
from datetime import datetime
from PIL import Image
# Authenticate with Hugging Face. Only call login() when a token is configured:
# without one, huggingface_hub may fall back to an interactive prompt, which a
# headless app cannot answer.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
# SmolAgent initialization (this model object is passed to every CodeAgent).
model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
# Globals
df_global = None  # cleaned DataFrame from the most recent upload_file call
target_column_global = None  # column name chosen via set_target_column
#File Upload and Cleanup
def upload_file(file):
    """Load an uploaded CSV/Excel file, clean it, and store it in ``df_global``.

    Args:
        file: The ``gr.File`` value. With ``type="filepath"`` this is a path
            string; objects exposing ``.name`` are accepted as well.

    Returns:
        ``(preview, dropdown_update)``: the first rows of the cleaned frame
        (or a one-cell error frame) and a ``gr.update`` with the column names
        for the target dropdown.
    """
    global df_global
    if file is None:
        return pd.DataFrame({"Error": ["No file uploaded."]}), gr.update(choices=[])
    # gr.File(type="filepath") hands over a plain str, which has no .name.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    ext = os.path.splitext(path)[-1].lower()  # accept ".CSV" as well as ".csv"
    try:
        df = pd.read_csv(path) if ext == ".csv" else pd.read_excel(path)
        df = clean_data(df)
    except Exception as exc:  # show parse failures in the UI instead of a traceback
        return pd.DataFrame({"Error": [f"Could not read file: {exc}"]}), gr.update(choices=[])
    df_global = df
    return df.head(), gr.update(choices=df.columns.tolist())
def set_target_column(col_name):
    """Record *col_name* as the prediction target and return a status message."""
    global target_column_global
    target_column_global = col_name
    message = f"✅ Target column set to: {col_name}"
    return message
def clean_data(df):
    """Return a cleaned, fully numeric copy of *df* ready for modelling.

    Steps:
    1. Drop columns and rows that are entirely empty.
    2. For every ``object`` column, strip ``$``/``,`` and surrounding
       whitespace and convert to numbers when every non-missing cell parses;
       otherwise the column keeps its ORIGINAL values unchanged.
    3. Integer-encode the remaining ``object`` columns; codes follow the
       sorted order of the stringified values (same as ``LabelEncoder``).
    4. Fill remaining NaNs with each numeric column's mean.

    The input frame is not modified.
    """
    df = df.dropna(how="all", axis=1).dropna(how="all", axis=0)

    for col in df.columns:
        if df[col].dtype != "object":
            continue
        original = df[col]
        # Stringify first so non-str cells (e.g. ints in a mixed column) are
        # not turned into NaN by the .str accessor; then restore real gaps.
        cleaned = (
            original.astype(str)
            .str.replace(r"[$,]", "", regex=True)
            .str.strip()
            .where(original.notna())
        )
        try:
            df[col] = pd.to_numeric(cleaned)
        except (ValueError, TypeError):
            pass  # not numeric: leave the original values untouched

    for col in df.select_dtypes(include="object").columns:
        codes, _ = pd.factorize(df[col].astype(str), sort=True)
        df[col] = codes

    return df.fillna(df.mean(numeric_only=True))
# Extract a JSON object from CodeAgent output when it is not already in the expected format
import json
import re
import ast
# Fenced code blocks labelled json/py/python (or unlabelled); captures the body.
_CODE_BLOCK_RE = re.compile(r"```(?:json|py|python)?[ \t]*\r?\n([\s\S]*?)```")
# Keys the analysis prompt asks for; a dict carrying one of them is preferred
# over unrelated dict literals (e.g. plotting kwargs) found in the same text.
_EXPECTED_KEYS = ("observations", "insights")


def _parse_dict_literal(text):
    """Parse *text* as a JSON object or Python dict literal; return it or None."""
    text = text.strip()
    try:
        value = json.loads(text)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        try:
            # literal_eval only evaluates literals, so it is safe on model output.
            value = ast.literal_eval(text)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
    return value if isinstance(value, dict) else None


def _iter_brace_spans(text):
    """Yield each outermost balanced ``{...}`` substring of *text*.

    Inside an object, braces within single- or double-quoted strings are
    ignored, so values like ``"a {b}"`` do not end the object early.
    """
    depth = 0
    start = 0
    quote = None
    escaped = False
    for pos, ch in enumerate(text):
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
            continue
        if ch in "\"'" and depth:
            quote = ch
        elif ch == "{":
            if depth == 0:
                start = pos
            depth += 1
        elif ch == "}" and depth:
            depth -= 1
            if depth == 0:
                yield text[start:pos + 1]


def _best_dict(candidates):
    """Return the first parsed dict with an expected key, else the first parsed dict."""
    fallback = None
    for candidate in candidates:
        parsed = _parse_dict_literal(candidate)
        if parsed is None:
            continue
        if any(key in parsed for key in _EXPECTED_KEYS):
            return parsed
        if fallback is None:
            fallback = parsed
    return fallback


def extract_json_from_codeagent_output(raw_output):
    """Best-effort extraction of a result dict from a CodeAgent answer.

    Strategies, in order:
    1. A dict input: parse its string ``"output"`` entry if that is a dict
       literal, otherwise return the dict itself.
    2. The whole text as JSON or a Python dict literal.
    3. Balanced ``{...}`` spans inside fenced code blocks.
    4. Balanced ``{...}`` spans anywhere in the text.
    In steps 3 and 4, a dict containing ``observations``/``insights`` wins.

    Args:
        raw_output: The agent's return value (dict, str, or anything
            ``str()``-able).

    Returns:
        The extracted dict, or ``{"error": "Failed to extract structured JSON"}``.
    """
    try:
        # Case 1: already a dict (possibly wrapping a stringified result)
        if isinstance(raw_output, dict):
            inner = raw_output.get("output")
            if isinstance(inner, str):
                parsed = _parse_dict_literal(inner)
                if parsed is not None:
                    return parsed
            return raw_output

        text = raw_output if isinstance(raw_output, str) else str(raw_output)

        # Case 2: the whole text is the object
        parsed = _parse_dict_literal(text)
        if parsed is not None:
            return parsed

        # Case 3: objects inside fenced code blocks
        parsed = _best_dict(
            span
            for block in _CODE_BLOCK_RE.findall(text)
            for span in _iter_brace_spans(block)
        )
        if parsed is not None:
            return parsed

        # Case 4: any balanced object anywhere in the output
        parsed = _best_dict(_iter_brace_spans(text))
        if parsed is not None:
            return parsed
    except Exception as e:  # best-effort boundary: never crash the UI here
        print(f"[extract_json] Error: {e}")
    # Case 5: If everything fails
    return {"error": "Failed to extract structured JSON"}
import pandas as pd
import tempfile
# Prompt sent to the CodeAgent; the cleaned file path and any notes are passed
# separately through ``additional_args``.
_ANALYSIS_PROMPT = """
You are a data analysis agent.Follow these instructions EXACT order:
1. Load the data from the given `source_file` ONLY. DO NOT create your OWN DATA.
2. Analyze the data and generate up to 3 clear insight and 3 visualization
3. Save all figures to `./figures` as PNG using matplotlib or seaborn.
4. Use only authorized imports: `pandas`, `numpy`, `matplotlib.pyplot`, `seaborn`, `json`.
5. DO NOT return any explanations, thoughts, or narration outside the final JSON block
6. Run only 5 steps and return output in less than a minute.
7. ONLY include natural language as observation value or insight value.
8. ONLY output a single, valid JSON block. No markdown or extra text.
9. Output ONLY the following JSON code block format, exactly:
{
'observations': {
'observation_1_key': 'observation_1_value',
...
},
'insights': {
'insight_1_key': 'insight_1_value',
...
}
}
"""

# Extensions read with pandas.read_excel; everything else is read as CSV, so
# CSV stays the default for unknown or missing extensions.
_EXCEL_EXTENSIONS = (".xls", ".xlsx", ".xlsm", ".xlsb", ".ods")


def _read_uploaded_table(file):
    """Read an uploaded file (path, or object with ``.name``) into a DataFrame."""
    path = file if isinstance(file, (str, os.PathLike)) else getattr(file, "name", file)
    ext = os.path.splitext(str(path))[-1].lower()
    if ext in _EXCEL_EXTENSIONS:
        return pd.read_excel(path)
    return pd.read_csv(path)


def _summary_section(heading, items):
    """Render *items* (dict, list, or scalar) under *heading* as an escaped HTML list."""
    import html

    if isinstance(items, dict):
        entries = list(items.items())
    elif isinstance(items, (list, tuple)):
        entries = list(enumerate(items, start=1))
    else:
        entries = [("value", items)]
    rows = "".join(
        f"<li><b>{html.escape(str(key))}:</b> {html.escape(str(value))}</li>"
        for key, value in entries
    )
    return f"{heading}<ul>{rows}</ul>"


def analyze_data(csv_file, additional_notes=""):
    """Run the SmolAgent analysis on an uploaded table and render the result.

    The file is cleaned with ``clean_data`` and written to
    ``./cleaned_data.csv``. ``./figures`` is emptied, then the agent runs
    against the cleaned file. Timing, memory, and figures are logged to W&B.

    Args:
        csv_file: Uploaded file path (CSV, or Excel by extension).
        additional_notes: Free-form notes forwarded to the agent and W&B config.

    Returns:
        ``(summary_html, visuals)``: an HTML summary (model text is escaped)
        and the sorted list of image paths found in ``./figures``.
    """
    import html

    start_time = time.time()
    process = psutil.Process(os.getpid())
    initial_memory = process.memory_info().rss / 1024 ** 2

    # Clean the uploaded file
    try:
        df = clean_data(_read_uploaded_table(csv_file))
    except Exception as e:
        return f"<p style='color:red'><b>Error loading or cleaning CSV:</b> {html.escape(str(e))}</p>", []

    # Save cleaned CSV to disk (using a stable location)
    cleaned_csv_path = "./cleaned_data.csv"
    df.to_csv(cleaned_csv_path, index=False)

    # Start from an empty figures folder so only this run's plots are shown.
    figures_dir = "./figures"
    if os.path.exists(figures_dir):
        shutil.rmtree(figures_dir)
    os.makedirs(figures_dir, exist_ok=True)

    wandb.login(key=os.environ.get('WANDB_API_KEY'))
    run = wandb.init(project="huggingface-data-analysis", config={
        "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
        "additional_notes": additional_notes,
        "source_file": cleaned_csv_path
    })
    try:
        agent = CodeAgent(
            tools=[],
            model=model,
            additional_authorized_imports=[
                "numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn", "json"
            ],
            max_steps=8,
        )
        try:
            raw_output = agent.run(
                _ANALYSIS_PROMPT,
                additional_args={"additional_notes": additional_notes, "source_file": cleaned_csv_path},
            )
        except Exception as e:
            # Surface agent failures in the summary instead of crashing the UI.
            print(f"Agent run failed: {e}")
            raw_output = {"error": f"Agent run failed: {e}"}

        if isinstance(raw_output, dict) and "output" in raw_output:
            print(f"Raw output: {str(raw_output['output'])[:1000]}")
        else:
            print(f"Raw output: {str(raw_output)[:1000]}")

        parsed_result = extract_json_from_codeagent_output(raw_output)
        if not isinstance(parsed_result, dict) or not parsed_result:
            parsed_result = {"error": "Failed to extract structured JSON"}

        # Log execution stats
        execution_time = time.time() - start_time
        memory_usage = process.memory_info().rss / 1024 ** 2 - initial_memory
        wandb.log({
            "execution_time_sec": round(execution_time, 2),
            "memory_usage_mb": round(memory_usage, 2)
        })

        # Collect figures in a stable order (the agent may also remove the folder).
        visuals = sorted(
            os.path.join(figures_dir, f)
            for f in os.listdir(figures_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ) if os.path.isdir(figures_dir) else []
        for viz in visuals:
            wandb.log({os.path.basename(viz): wandb.Image(viz)})
    finally:
        run.finish()  # always close the run, even if logging or the agent failed

    # HTML Summary
    summary_html = "<h3>📊 Data Analysis Summary</h3>"
    if "observations" in parsed_result:
        summary_html += _summary_section("<h4>🔍 Observations</h4>", parsed_result["observations"])
    if "insights" in parsed_result:
        summary_html += _summary_section("<h4>💡 Insights</h4>", parsed_result["insights"])
    if "error" in parsed_result:
        summary_html += f"<p style='color:red'><b>Error:</b> {html.escape(str(parsed_result['error']))}</p>"
    return summary_html, visuals
def format_analysis_report(raw_output, visuals):
    """Render an agent result as a styled HTML report.

    Args:
        raw_output: A dict, or a JSON string encoding a dict, with optional
            ``observations`` and ``insights`` entries.
        visuals: Image paths forwarded to ``format_insights`` and returned
            unchanged.

    Returns:
        ``(html_report, visuals)``. If *raw_output* cannot be read as a dict,
        or rendering fails, the report is the HTML-escaped raw text in ``<pre>``.
    """
    import html
    import json

    def _raw_fallback():
        # Escape so model/agent text cannot inject markup into the page.
        return f"<pre>{html.escape(str(raw_output))}</pre>", visuals

    if isinstance(raw_output, dict):
        analysis_dict = raw_output
    else:
        try:
            analysis_dict = json.loads(str(raw_output))
        except (json.JSONDecodeError, TypeError) as e:
            print(f"Error parsing CodeAgent output: {e}")
            return _raw_fallback()
    if not isinstance(analysis_dict, dict):
        return _raw_fallback()  # e.g. a JSON list: nothing to .get() from

    try:
        report = f"""
<div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
<h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
<div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
<h2 style="color: #2B547E;">🔍 Key Observations</h2>
{format_observations(analysis_dict.get('observations', {}))}
</div>
<div style="margin-top: 30px;">
<h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
{format_insights(analysis_dict.get('insights', {}), visuals)}
</div>
</div>
"""
        return report, visuals
    except Exception as e:
        print(f"Error in format_analysis_report: {e}")
        return _raw_fallback()
def format_observations(observations):
    """Render observations as styled HTML cards.

    Args:
        observations: A dict of ``key: text``; a list/tuple of texts is also
            accepted and numbered ``Observation 1``, ``Observation 2``, ...

    Returns:
        The cards joined by newlines. Keys become title-cased headings
        (underscores replaced by spaces). Keys and values are HTML-escaped
        because they originate from model output.
    """
    import html

    if isinstance(observations, dict):
        entries = observations.items()
    else:
        entries = [(f"observation_{i}", v) for i, v in enumerate(observations, start=1)]

    cards = []
    for key, value in entries:
        # Title-case before escaping so entities like "&lt;" are not mangled.
        title = html.escape(str(key).replace('_', ' ').title())
        body = html.escape(str(value))
        cards.append(f"""
<div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
<h3 style="margin: 0 0 10px 0; color: #4A708B;">{title}</h3>
<pre style="
margin: 0;
padding: 10px;
background: #eef2f7;
border-radius: 4px;
color: #1f2d3d;
font-size: 14px;
font-family: 'Courier New', Courier, monospace;
white-space: pre-wrap;
opacity: 1;
">{body}</pre>
</div>
""")
    return '\n'.join(cards)
def format_insights(insights, visuals):
    """Render insights as numbered HTML cards, pairing the i-th card with the i-th visual.

    Args:
        insights: Either a dict of ``title: text`` (old format) or a list whose
            items are dicts with ``"insight"`` and optional ``"category"``
            (new format). Non-dict list items are shown as plain text.
        visuals: Image paths; card *i* gets ``visuals[i]`` when it exists.
            ``None`` is treated as no images.

    Returns:
        The cards joined by newlines, or a short notice for unsupported input.
        Titles and texts are HTML-escaped because they come from model output.
    """
    import html

    visuals = visuals or []
    if isinstance(insights, dict):
        # Old format (dict of key: text)
        insight_items = list(insights.items())
    elif isinstance(insights, list):
        # New format (list of dicts with "insight" and optional "category")
        insight_items = []
        for idx, item in enumerate(insights):
            default_title = f"Insight {idx+1}"
            if isinstance(item, dict):
                insight_items.append((item.get("category", default_title), item.get("insight", "")))
            else:
                insight_items.append((default_title, item))
    else:
        return "<p>No insights available or incorrect format.</p>"

    cards = []
    for idx, (title, insight) in enumerate(insight_items):
        image_html = ''
        if idx < len(visuals):
            # NOTE(review): only the basename goes into the URL, so a path such
            # as ./figures/x.png loses its directory; confirm the "file/" route
            # resolves it for the Gradio version in use.
            src = html.escape(os.path.basename(visuals[idx]))
            image_html = f'<img src="file/{src}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">'
        cards.append(f"""
<div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
<div style="display: flex; align-items: center; gap: 10px;">
<div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
<div>
<h4 style="margin: 0; color: #2B547E;">{html.escape(str(title))}</h4>
<p style="margin: 5px 0 0 0; font-size: 16px; color: #333; font-weight: 500;">{html.escape(str(insight))}</p>
</div>
</div>
{image_html}
</div>
""")
    return '\n'.join(cards)
from sklearn.model_selection import StratifiedKFold, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, VotingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, precision_score, recall_score
import optuna
def compare_models():
    """Cross-validate several classifiers on ``df_global`` and plot their accuracy.

    Each model is wrapped in ``StandardScaler -> model`` so the scaler is re-fit
    inside every CV fold. Metrics are logged to W&B, and a W&B run is started
    if none is active.

    Returns:
        ``(results_df, plot_path)``: one row per model (CV mean/std accuracy
        and weighted F1/precision/recall from cross-validated predictions) and
        the saved bar-plot path. If no data or no valid target is set, returns
        a one-cell error frame and ``None``.
    """
    import seaborn as sns
    from sklearn.model_selection import cross_val_predict, cross_val_score
    from sklearn.pipeline import make_pipeline

    if df_global is None:
        return pd.DataFrame({"Error": ["Please upload and preprocess a dataset first."]}), None
    target = target_column_global
    if target is None or target not in df_global.columns:
        return pd.DataFrame({"Error": ["Please select a valid target column first."]}), None

    X = df_global.drop(target, axis=1)
    y = df_global[target]
    # If the target is categorical, encode it
    if y.dtype == 'object':
        y = LabelEncoder().fit_transform(y)

    models = {
        "RandomForest": RandomForestClassifier(),
        "LogisticRegression": LogisticRegression(max_iter=1000),
        "GradientBoosting": GradientBoostingClassifier(),
        # Consider adding more models like XGBoost
        "Voting Classifier": VotingClassifier(
            estimators=[
                ('rf', RandomForestClassifier()),
                ('lr', LogisticRegression(max_iter=1000)),
                ('gb', GradientBoostingClassifier()),
            ],
            voting='hard',
        ),
    }

    if wandb.run is None:
        wandb.init(project="model_comparison", name="compare_models", reinit=True)

    results = []
    for name, estimator in models.items():
        # Scaling inside the pipeline keeps test-fold statistics out of training.
        pipeline = make_pipeline(StandardScaler(), estimator)
        scores = cross_val_score(pipeline, X, y, cv=5)
        y_pred = cross_val_predict(pipeline, X, y, cv=5)
        metrics = {
            "Model": name,
            "CV Mean Accuracy": np.mean(scores),
            "CV Std Dev": np.std(scores),
            "F1 Score": f1_score(y, y_pred, average="weighted", zero_division=0),
            "Precision": precision_score(y, y_pred, average="weighted", zero_division=0),
            "Recall": recall_score(y, y_pred, average="weighted", zero_division=0),
        }
        wandb.log({f"{name}_{k.replace(' ', '_').lower()}": v for k, v in metrics.items() if isinstance(v, (float, int))})
        results.append(metrics)

    results_df = pd.DataFrame(results)

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        sns.barplot(data=results_df, x="Model", y="CV Mean Accuracy", palette="Blues_d", ax=ax)
        ax.set_title("Model Comparison (CV Mean Accuracy)")
        ax.set_ylim(0, 1)
        fig.tight_layout()
        plot_path = "./model_comparison.png"
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    return results_df, plot_path
# Data preparation: train/test split on the selected target column (used by train_model)
def prepare_data(df):
    """Split *df* into train/test sets on the globally selected target column.

    Args:
        df: The cleaned dataset (normally ``df_global``).

    Returns:
        ``(X_train, X_test, y_train, y_test)`` from a 70/30 split with
        ``random_state=42``, so repeated calls produce the same split.

    Raises:
        ValueError: if *df* is None, no target column is set, or the target
            column is not present in *df*.
    """
    if df is None:
        raise ValueError("No dataset loaded. Please upload a file first.")
    target = target_column_global
    if target is None:
        raise ValueError("Target column not set.")
    if target not in df.columns:
        raise ValueError(f"Target column {target!r} not found in the dataset.")
    from sklearn.model_selection import train_test_split

    X = df.drop(columns=[target])
    y = df[target]
    return train_test_split(X, y, test_size=0.3, random_state=42)
def train_model(_=None):
    """Tune a RandomForest with Optuna, evaluate on a hold-out split, and log to W&B.

    The argument (the upload widget's value from the UI) is ignored; data comes
    from ``df_global`` via ``prepare_data``.

    Returns:
        ``(metrics, trials_df)``: hold-out accuracy and weighted
        precision/recall/F1, plus a frame of the top 7 completed trials
        (params and ``score``). On failure, ``({"error": message}, empty frame)``.
    """
    wandb_run = None
    try:
        wandb.login(key=os.environ.get("WANDB_API_KEY"))
        wandb_run = wandb.init(
            project="huggingface-data-analysis",
            name=f"Optuna_Run_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            reinit=True
        )
        X_train, X_test, y_train, y_test = prepare_data(df_global)

        def objective(trial):
            params = {
                "n_estimators": trial.suggest_int("n_estimators", 50, 200),
                "max_depth": trial.suggest_int("max_depth", 3, 10),
            }
            score = cross_val_score(RandomForestClassifier(**params), X_train, y_train, cv=3).mean()
            wandb.log({**params, "cv_score": score})
            return score  # Optuna maximizes this value

        study = optuna.create_study(direction="maximize")
        study.optimize(objective, n_trials=15)

        forest = RandomForestClassifier(**study.best_params)
        forest.fit(X_train, y_train)
        y_pred = forest.predict(X_test)
        metrics = {
            "accuracy": accuracy_score(y_test, y_pred),
            "precision": precision_score(y_test, y_pred, average="weighted", zero_division=0),
            "recall": recall_score(y_test, y_pred, average="weighted", zero_division=0),
            "f1_score": f1_score(y_test, y_pred, average="weighted", zero_division=0),
        }
        wandb.log(metrics)

        # Trials without a value (e.g. failed ones) cannot be ranked by score.
        finished = [t for t in study.trials if t.value is not None]
        top_trials = sorted(finished, key=lambda t: t.value, reverse=True)[:7]
        # Dict merge avoids a TypeError if a param were ever named "score".
        trials_df = pd.DataFrame([{**t.params, "score": t.value} for t in top_trials])
        return metrics, trials_df
    except Exception as e:
        print(f"Training Error: {e}")
        return {"error": str(e)}, pd.DataFrame()
    finally:
        if wandb_run is not None:
            wandb_run.finish()
def _save_shap_summary(forest, X_test, class_idx=0):
    """Save a SHAP summary plot for *forest* on *X_test* and return the image path.

    If SHAP fails, an image showing the error is saved to
    ``./shap_error.png`` and that path is returned instead.
    """
    shap_path = "./shap_plot.png"
    try:
        shap_values = shap.TreeExplainer(forest).shap_values(X_test)
        if isinstance(shap_values, list):
            sv = np.asarray(shap_values[class_idx])  # one array per class
        else:
            sv = np.asarray(shap_values)
        if sv.ndim == 3:
            # Recent SHAP returns (samples, features, classes): take one class
            # instead of flattening, which would misalign feature columns.
            sv = sv[:, :, class_idx]
        elif sv.ndim > 3:
            sv = sv.reshape(sv.shape[0], -1)

        num_features = sv.shape[1]
        if num_features == X_test.shape[1]:
            features = X_test  # real values drive the plot's colour scale
        else:
            # Shape-mismatch fallback: placeholder values with safe names.
            if num_features <= X_test.shape[1]:
                names = list(X_test.columns[:num_features])
            else:
                names = [f"Feature_{i}" for i in range(num_features)]
            features = pd.DataFrame(np.zeros_like(sv), columns=names)

        shap.summary_plot(sv, features, show=False)
        plt.title("SHAP Summary")
        plt.savefig(shap_path)
        if wandb.run:
            wandb.log({"shap_summary": wandb.Image(shap_path)})
    except Exception as e:
        shap_path = "./shap_error.png"
        print("SHAP plotting failed:", e)
        plt.figure(figsize=(6, 3))
        plt.text(0.5, 0.5, f"SHAP Error:\n{str(e)}", ha='center', va='center')
        plt.axis('off')
        plt.savefig(shap_path)
        if wandb.run:
            wandb.log({"shap_error": wandb.Image(shap_path)})
    finally:
        plt.close("all")  # release the figures created above
    return shap_path


def _save_lime_explanation(forest, X_train, X_test, y_train):
    """Explain the first test row with LIME, save the figure, and return its path."""
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        X_train.values,
        feature_names=X_train.columns.tolist(),
        class_names=[str(c) for c in np.unique(y_train)],
        mode='classification'
    )
    explanation = lime_explainer.explain_instance(X_test.iloc[0].values, forest.predict_proba)
    lime_fig = explanation.as_pyplot_figure()
    lime_path = "./lime_plot.png"
    try:
        lime_fig.savefig(lime_path)
    finally:
        plt.close(lime_fig)  # close so repeated clicks don't accumulate figures
    if wandb.run:
        wandb.log({"lime_explanation": wandb.Image(lime_path)})
    return lime_path


def explainability(_=None):
    """Fit a RandomForest on ``df_global`` and produce SHAP and LIME plots.

    The argument defaults to None because the UI wires this handler with
    ``inputs=[]`` and calls it without arguments.

    Returns:
        ``(shap_path, lime_path)`` image paths.

    Raises:
        gr.Error: if no dataset is loaded or no valid target column is set.
    """
    import warnings

    if df_global is None:
        raise gr.Error("Please upload a dataset first.")
    target = target_column_global
    if target is None or target not in df_global.columns:
        raise gr.Error("Please select a valid target column first.")

    X = df_global.drop(target, axis=1)
    y = df_global[target]
    if y.dtype == "object":
        y = LabelEncoder().fit_transform(y)

    # Suppress library warnings for this call only, not process-wide.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        X_train, X_test, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)
        forest = RandomForestClassifier()
        forest.fit(X_train, y_train)
        shap_path = _save_shap_summary(forest, X_test)
        lime_path = _save_lime_explanation(forest, X_train, X_test, y_train)
    return shap_path, lime_path
# Define this BEFORE the Gradio app layout
def update_target_choices():
    """Build a dropdown update listing the loaded dataset's columns (empty if none)."""
    columns = [] if df_global is None else df_global.columns.tolist()
    return gr.update(choices=columns)
# Gradio UI: layout, then event wiring, all inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("## 📊 AI-Powered Data Analysis with Hyperparameter Optimization")
    with gr.Row():
        # Left column: upload, cleaned-data preview, and target selection.
        with gr.Column():
            file_input = gr.File(label="Upload CSV or Excel", type="filepath")
            df_output = gr.DataFrame(label="Cleaned Data Preview")
            target_dropdown = gr.Dropdown(label="Select Target Column", choices=[], interactive=True)
            target_status = gr.Textbox(label="Target Column Status", interactive=False)
            # upload_file returns both the preview and the dropdown's new choices.
            file_input.change(fn=upload_file, inputs=file_input, outputs=[df_output, target_dropdown])
            # Disabled: redundant, since upload_file already refreshes the dropdown choices.
            #file_input.change(fn=update_target_choices, inputs=[], outputs=target_dropdown)
            target_dropdown.change(fn=set_target_column, inputs=target_dropdown, outputs=target_status)
        # Right column: agent-generated insights and figures.
        with gr.Column():
            insights_output = gr.HTML(label="Insights from SmolAgent")
            visual_output = gr.Gallery(label="Visualizations (Auto-generated by Agent)", columns=2)
            agent_btn = gr.Button("Run AI Agent (3 Insights + 3 Visualizations)")
    # Optuna tuning results.
    with gr.Row():
        train_btn = gr.Button("Train Model with Optuna + WandB")
        metrics_output = gr.JSON(label="Performance Metrics")
        trials_output = gr.DataFrame(label="Top 7 Hyperparameter Trials")
    # SHAP / LIME plots.
    with gr.Row():
        explain_btn = gr.Button("SHAP + LIME Explainability")
        shap_img = gr.Image(label="SHAP Summary Plot")
        lime_img = gr.Image(label="LIME Explanation")
    # Cross-validated model comparison.
    with gr.Row():
        compare_btn = gr.Button("Compare Models (A/B Testing)")
        compare_output = gr.DataFrame(label="Model Comparison (CV + Metrics)")
        compare_img = gr.Image(label="Model Accuracy Plot")
    # Button handlers. Handlers wired with inputs=[] receive no arguments.
    agent_btn.click(fn=analyze_data, inputs=[file_input], outputs=[insights_output, visual_output])
    train_btn.click(fn=train_model, inputs=[file_input], outputs=[metrics_output, trials_output])
    explain_btn.click(fn=explainability, inputs=[], outputs=[shap_img, lime_img])
    compare_btn.click(fn=compare_models, inputs=[], outputs=[compare_output, compare_img])
# Start the Gradio server (debug=True keeps errors visible in the console).
demo.launch(debug=True)