# NOTE(review): removed Hugging Face Spaces page-scrape residue (status banner,
# commit hashes, and gutter line numbers) — it was not part of the source file.
import os
import sys
import json
from pathlib import Path
import gradio as gr
import time
import smtplib
from email.message import EmailMessage
# Make your repo importable (expecting a folder named causal-agent at repo root)
# NOTE(review): this must run before the `auto_causal` import below; presumably
# the package lives inside ./causal-agent — verify the repo layout matches.
sys.path.append(str(Path(__file__).parent / "causal-agent"))
# Directory holding the bundled example CSVs; override with the
# EXAMPLE_CSV_PATH env var, otherwise default to the app's own directory.
EXAMPLE_CSV_PATH = os.getenv("EXAMPLE_CSV_PATH", str(Path(__file__).parent))
from auto_causal.agent import run_causal_analysis # uses env for provider/model
# -------- LLM config (OpenAI only; key via HF Secrets) --------
# setdefault only fills in values when the env vars are unset, so values
# configured as Space Secrets / variables always take precedence.
os.environ.setdefault("LLM_PROVIDER", "openai")
os.environ.setdefault("LLM_MODEL", "gpt-4o")
def _get_openai_client():
if os.getenv("LLM_PROVIDER", "openai") != "openai":
raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.")
if not os.getenv("OPENAI_API_KEY"):
raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).")
from openai import OpenAI
return OpenAI()
# System prompt for the summarization call in _summarize_with_llm(); runtime
# string — do not edit casually, the LLM's output format depends on it.
SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
You will be given:
1) The original research question.
2) The analysis method used.
3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
4) A brief dataset description.
Your task is to produce a clear, concise, and non-technical summary that:
- Directly answers the research question.
- States whether the effect is statistically significant.
- Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
- Mentions the method used in one sentence.
- Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
Formatting rules:
- Use bullet points or short paragraphs.
- Report effect sizes to two decimal places.
- Clearly state the interpretation in plain English without technical jargon.
Example Output Structure:
- **Method:** [Name of method + 1-line rationale]
- **Key Finding:** [Main answer to the research question]
- **Details:**
- [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 — [Significance comment]
- …
- **Rank Order of Effects:** [Largest → Smallest]
"""
def _extract_minimal_payload(agent_result: dict) -> dict:
res = agent_result or {}
results = res.get("results", {}) if isinstance(res.get("results"), dict) else {}
inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {}
dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {}
question = (
results.get("original_query")
or dataset_analysis.get("original_query")
or res.get("query")
or "N/A"
)
method = (
inner.get("method_used")
or res.get("method_used")
or results.get("method_used")
or "N/A"
)
effect_estimate = inner.get("effect_estimate") or res.get("effect_estimate") or {}
confidence_interval = inner.get("confidence_interval") or res.get("confidence_interval") or {}
standard_error = inner.get("standard_error") or res.get("standard_error") or {}
p_value = inner.get("p_value") or res.get("p_value") or {}
dataset_desc = results.get("dataset_description") or res.get("dataset_description") or "N/A"
return {
"original_question": question,
"method_used": method,
"estimates": {
"effect_estimate": effect_estimate,
"confidence_interval": confidence_interval,
"standard_error": standard_error,
"p_value": p_value,
},
"dataset_description": dataset_desc,
}
def _summarize_with_llm(payload: dict) -> str:
    """Ask the configured OpenAI chat model for a plain-English summary.

    Args:
        payload: minimal results dict produced by _extract_minimal_payload().

    Returns:
        The model's summary text, stripped of surrounding whitespace.

    Raises:
        RuntimeError: propagated from _get_openai_client() when the
            provider/key configuration is invalid.
    """
    client = _get_openai_client()
    # CONSISTENCY FIX: fall back to the same default the module sets at import
    # time ("gpt-4o"); the previous fallback ("gpt-4o-mini") would silently
    # switch models if LLM_MODEL were ever removed from the environment.
    model_name = os.getenv("LLM_MODEL", "gpt-4o")
    user_content = "Summarize the following causal analysis results:\n\n" + json.dumps(payload, indent=2, ensure_ascii=False)
    resp = client.chat.completions.create(
        model=model_name,
        messages=[{"role": "system", "content": SYSTEM_PROMPT},
                  {"role": "user", "content": user_content}],
        temperature=0,  # deterministic summaries for reproducible UI output
    )
    return resp.choices[0].message.content.strip()
def _html_panel(title, body_html):
return f"""
<div style='padding:15px;border:1px solid #ddd;border-radius:8px;margin:5px 0;background-color:#333333;'>
<h3 style='margin:0 0 10px 0;font-size:18px;'>{title}</h3>
<div style='line-height:1.6;'>{body_html}</div>
</div>
"""
def _warn_html(text):
return f"<div style='padding:10px;border:1px solid #ffc107;border-radius:5px;color:#ffc107;background-color:#333333;'>⚠️ {text}</div>"
def _err_html(text):
return f"<div style='padding:10px;border:1px solid #dc3545;border-radius:5px;color:#dc3545;background-color:#333333;'>❌ {text}</div>"
def _ok_html(text):
return f"<div style='padding:10px;border:1px solid #2ea043;border-radius:5px;color:#2ea043;background-color:#333333;'>✅ {text}</div>"
# --- Email support ---
import base64, json, requests
from email.message import EmailMessage
def _gmail_access_token() -> str:
    """Exchange the stored Gmail OAuth refresh token for a short-lived access token.

    Reads GMAIL_CLIENT_ID / GMAIL_CLIENT_SECRET / GMAIL_REFRESH_TOKEN from the
    environment. Raises requests.HTTPError on a non-2xx response.
    """
    form = {
        "grant_type": "refresh_token",
        "client_id": os.getenv("GMAIL_CLIENT_ID"),
        "client_secret": os.getenv("GMAIL_CLIENT_SECRET"),
        "refresh_token": os.getenv("GMAIL_REFRESH_TOKEN"),
    }
    response = requests.post("https://oauth2.googleapis.com/token", data=form, timeout=20)
    response.raise_for_status()
    return response.json()["access_token"]
def send_email(recipient: str, subject: str, body_text: str,
               attachment_name: str = None, attachment_json: dict = None) -> str:
    """
    Sends via Gmail API. Returns '' on success, or an error string.
    """
    sender = os.getenv("EMAIL_FROM")
    credentials = (os.getenv("GMAIL_CLIENT_ID"), os.getenv("GMAIL_CLIENT_SECRET"),
                   os.getenv("GMAIL_REFRESH_TOKEN"), sender)
    if not all(credentials):
        return "Gmail API not configured (set GMAIL_CLIENT_ID, GMAIL_CLIENT_SECRET, GMAIL_REFRESH_TOKEN, EMAIL_FROM)."
    try:
        # Assemble the RFC822 message.
        message = EmailMessage()
        message["From"] = sender
        message["To"] = recipient
        message["Subject"] = subject
        message.set_content(body_text)
        # Optional JSON attachment (only when both name and payload are given).
        if attachment_name and attachment_json is not None:
            blob = json.dumps(attachment_json, indent=2).encode("utf-8")
            message.add_attachment(blob, maintype="application", subtype="json",
                                   filename=attachment_name)
        # Gmail's send endpoint expects the raw message base64url-encoded.
        raw = base64.urlsafe_b64encode(message.as_bytes()).decode("utf-8")
        bearer = _gmail_access_token()
        response = requests.post(
            "https://gmail.googleapis.com/gmail/v1/users/me/messages/send",
            headers={"Authorization": f"Bearer {bearer}",
                     "Content-Type": "application/json"},
            json={"raw": raw},
            timeout=20,
        )
        if response.status_code >= 400:
            return f"Gmail API error {response.status_code}: {response.text[:300]}"
        return ""
    except Exception as e:
        # Best-effort by contract: callers display the string, never raise.
        return f"Email send failed: {e}"
def run_agent(query: str, csv_path: str, dataset_description: str, email: str):
    """Gradio generator callback: run the causal-analysis pipeline, streaming
    progress panels to the UI as they become available.

    Args:
        query: natural-language causal question (a default is substituted if empty).
        csv_path: filepath of the uploaded CSV (from gr.File, type="filepath").
        dataset_description: optional free-text schema/context for the dataset.
        email: optional recipient; results are emailed best-effort when set.

    Yields:
        4-tuples (method_html, effects_html, explanation_html, raw_json)
        matching the four output components wired in the Blocks UI.
    """
    start = time.time()
    processing_html = _html_panel("🔄 Analysis in Progress...", "<div style='font-size:14px;color:#bbb;'>This may take 1–2 minutes depending on dataset size.</div>")
    yield (processing_html, processing_html, processing_html, {"status": "Processing started..."})
    # Fail fast on missing prerequisites.
    if not os.getenv("OPENAI_API_KEY"):
        yield (_err_html("Set a Space Secret named OPENAI_API_KEY"), "", "", {})
        return
    if not csv_path:
        yield (_warn_html("Please upload a CSV dataset."), "", "", {})
        return
    try:
        step_html = _html_panel("📊 Running Causal Analysis...", "<div style='font-size:14px;color:#bbb;'>Analyzing dataset and selecting optimal method…</div>")
        yield (step_html, step_html, step_html, {"status": "Running causal analysis..."})
        result = run_causal_analysis(
            query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
            dataset_path=csv_path,
            dataset_description=(dataset_description or "").strip(),
        )
        llm_html = _html_panel("🤖 Generating Summary...", "<div style='font-size:14px;color:#bbb;'>Creating human-readable interpretation…</div>")
        yield (llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}})
    except Exception as e:
        yield (_err_html(str(e)), "", "", {})
        return
    try:
        payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
        method = payload.get("method_used", "N/A")
        method_html = _html_panel("Selected Method", f"<p style='margin:0;font-size:16px;'>{method}</p>")
        effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
        if effect_estimate:
            effects_html = _html_panel("Effect Estimates", f"<pre style='white-space:pre-wrap;margin:0;'>{json.dumps(effect_estimate, indent=2)}</pre>")
        else:
            effects_html = _warn_html("No effect estimates found")
        try:
            explanation = _summarize_with_llm(payload)
            explanation_html = _html_panel("Detailed Explanation", f"<div style='white-space:pre-wrap;'>{explanation}</div>")
        except Exception as e:
            # BUGFIX: bind `explanation` here too — it is interpolated into the
            # email body below, and previously a failed LLM call left it unbound,
            # raising NameError (outside any try) whenever an email was provided.
            explanation = f"LLM summary failed: {e}"
            explanation_html = _warn_html(f"LLM summary failed: {e}")
    except Exception as e:
        yield (_err_html(f"Failed to parse results: {e}"), "", "", {})
        return
    # Optional email send (best-effort)
    elapsed = time.time() - start
    if email and "@" in email:
        # Always send once results are ready; if you prefer thresholded behavior, check (elapsed > 600)
        subject = "Causal Agent Results"
        body = (
            "Here are your Causal Agent results.\n\n"
            f"Question: {payload.get('original_question','N/A')}\n"
            f"Method: {method}\n\n"
            f"Summary:\n{explanation}\n\n"
            "Raw JSON is attached.\n"
        )
        email_err = send_email(
            recipient=email.strip(),
            subject=subject,
            body_text=body,
            attachment_name="causal_results.json",
            attachment_json=(result if isinstance(result, dict) else {"results": result})
        )
        if email_err:
            explanation_html += _warn_html(email_err)
        else:
            explanation_html += _ok_html(f"Results emailed to {email.strip()}")
    yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
# ---------------- Gradio UI wiring ----------------
with gr.Blocks() as demo:
    gr.Markdown("# Causal AI Scientist")
    # gr.Markdown(
    #     """
    #     **Tips**
    #     - Be specific about your treatment, outcome, and control variables.
    #     - Include relevant context in the dataset description.
    #     - If you enter an email, we’ll send results when ready (only if SMTP is configured via env).
    #     """
    # )
    gr.Markdown(
        "Upload your dataset and ask causal questions in natural language. "
        "The system will automatically select the appropriate causal inference method and provide clear explanations."
    )
    with gr.Row():
        query = gr.Textbox(
            label="Your causal question (natural language)",
            placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
            lines=2,
        )
    with gr.Row():
        # type="filepath" means run_agent receives a path string, not bytes.
        csv_file = gr.File(label="Dataset (CSV)", file_types=[".csv"], type="filepath")
        dataset_description = gr.Textbox(
            label="Dataset description (optional)",
            placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
            lines=4,
        )
    # NEW: optional email field
    email = gr.Textbox(
        label="Email (optional)",
        placeholder="you@example.com — we'll email the results when ready (if email is configured).",
    )
    # Helpful examples (question + description)
    # NOTE(review): assumes the example CSVs ship alongside this file under
    # EXAMPLE_CSV_PATH — confirm they exist in the Space repo.
    gr.Examples(
        examples=[
            [
                "Does the adoption of the industrial reform policy increase the production output in factories?",
                """This dataset has been compiled to study the effect of an industrial reform policy on production output in several manufacturing factories. The data was collected over two-time frames, before and after the reform, from a group of factories which adopted the reform and another group that did not. The data collected includes continuous variables such as labor hours spent on production and the amount of raw materials used. Binary variables include the use of automation, energy efficiency of the machines used (energy efficient or not), and worker satisfaction (satisfied or not).
- factory_id: Unique identifier for each factory
- post_reform: Indicator if the data was collected after the reform (1) or before the reform (0)
- labor_hours: The number of labor hours spent on production
- raw_materials: The quantity of raw materials used in kilograms
- automation_use: Indicator if the factory uses automation in its production process (1) or not (0)
- energy_efficiency: Indicator if the factory uses energy-efficient machines (1) or not (0)
- worker_satisfaction: Indicator if workers reported being satisfied with their work environment (1) or not (0)""",
                EXAMPLE_CSV_PATH + '/did_canonical_data_1.csv'
            ],
            [
                "Does taking the newly developed medication have an impact on improving the lung capacity of patients with chronic obstructive pulmonary disease?",
                """This dataset is collected from a clinical trial study conducted by a pharmaceutical company. The study aims to understand if their newly developed medication can enhance the lung capacity of patients suffering from chronic obstructive pulmonary disease (COPD). Participants for the study, who are all COPD patients, were recruited from various healthcare centers and were randomly assigned to either receive the new medication or a placebo. Data was collected on various factors including the age of the participants, the number of years they have been smoking, their gender, whether or not they have a history of smoking, and whether or not they have a regular exercise habit.
'age' is the age of the participants in years.
'smoking_years' is the number of years the participant has been smoking.
'gender' is a binary variable where 1 represents male and 0 represents female.
'smoking_history' is a binary variable where 1 indicates the participant has a history of smoking, while 0 indicates no such history.
'exercise_habit' is a binary variable where 1 indicates the participant exercises regularly, while 0 indicates the participant does not.
'new_medication' is a binary variable where 1 indicates the participant was assigned the new medication, while 0 indicates the participant was assigned a placebo.
'lung_capacity' is the measured lung capacity of the participant.""",
                EXAMPLE_CSV_PATH + '/rct_data_4.csv',
            ],
        ],
        inputs=[query, dataset_description, csv_file],  # include the file component here
        label="Quick Examples (click to fill)",
    )
    run_btn = gr.Button("Run analysis", variant="primary")
    # Output layout: method + effects side by side, explanation full width,
    # raw JSON tucked into a collapsed accordion.
    with gr.Row():
        with gr.Column(scale=1):
            method_out = gr.HTML(label="Selected Method")
        with gr.Column(scale=1):
            effects_out = gr.HTML(label="Effect Estimates")
    with gr.Row():
        explanation_out = gr.HTML(label="Detailed Explanation")
    with gr.Accordion("Raw Results (Advanced)", open=False):
        raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
    # run_agent is a generator, so each yield updates all four outputs.
    run_btn.click(
        fn=run_agent,
        inputs=[query, csv_file, dataset_description, email],
        outputs=[method_out, effects_out, explanation_out, raw_results],
        show_progress=True
    )


if __name__ == "__main__":
    # queue() is required for generator callbacks to stream updates.
    demo.queue().launch()