# what-if-lab / app.py
# hra's picture
# Update app.py
# a83cf51 verified
# app.py – simple one-shot scenario mapper (MCP-ready)
import io, os, uuid, hashlib, json, warnings
from datetime import datetime
from pathlib import Path
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from duckduckgo_search import DDGS
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from PIL import Image as PILImage
# ── optional fancy clustering via OpenAI embeddings ────────────────────────
# `openai` is an optional dependency: when it cannot be imported the module
# still loads, `openai` is bound to None, and a warning announces the
# hash-based fallback.
try:
    import openai
except ImportError:
    openai = None
    warnings.warn(
        "`openai` package not found; "
        "falling back to hash clustering."
    )
# ── in-memory store (one run = one entry) ───────────────────────────────────
# Keyed by the session id string created in scenario_lab; each value is a dict
# with the keys "axis1", "axis2" and "points" (the list returned by cluster).
# NOTE(review): nothing in the code visible here ever removes an entry, so the
# store appears to grow for the lifetime of the process — confirm.
LABS: dict[str, dict] = {}
# ── helpers ────────────────────────────────────────────────────────────────
def web_search(q: str, k: int = 20):
    """Run a DuckDuckGo text search for *q* and return the hits as strings.

    *k* is forwarded as ``max_results``. Every hit becomes
    ``"<title> – <body>"``; a hit without a ``body`` key gets an empty body.
    """
    snippets = []
    with DDGS() as client:
        for hit in client.text(q, max_results=k):
            title = hit['title']
            body = hit.get('body', '')
            snippets.append(f"{title} – {body}")
    return snippets
def deterministic_xy(txt: str):
    """Derive a repeatable (x, y) pair in [-1, 1) from the SHA-256 of *txt*.

    The digest is read as one big integer. ``x`` comes from that integer
    modulo 1000, ``y`` from the integer shifted right by 12 bits modulo 1000;
    both buckets (0..999) are then mapped linearly onto -1.0 .. 0.998.
    """
    digest = hashlib.sha256(txt.encode()).digest()
    value = int.from_bytes(digest, "big")

    def scale(bucket):
        # 0..999  ->  -1.0 .. 0.998
        return bucket / 500 - 1

    x_bucket = value % 1000
    y_bucket = (value >> 12) % 1000
    return (scale(x_bucket), scale(y_bucket))
def embed(texts: list[str], key: str | None):
if openai is None or not key:
return None # Fallback: hashing
openai.api_key = key
resp = openai.embeddings.create(model="text-embedding-3-small", input=texts)
return np.array([d.embedding for d in resp.data])
def _rescale(values):
    """Map *values* linearly onto [-1, 1); the 1e-4 avoids division by zero."""
    return (values - values.min()) / (np.ptp(values) + 1e-4) * 2 - 1


def cluster(snips: list[str], embeds):
    """Turn snippets into labelled 2-D points for the scenario map.

    Args:
        snips: Text snippets.
        embeds: Embedding matrix for *snips*, or ``None`` to use the
            hash-based fallback.

    Returns:
        A list of ``(x, y, label)`` tuples where ``label`` is a snippet
        truncated to 40 characters.

        * Without embeddings: one point per snippet for the first 16
          snippets, placed by ``deterministic_xy``.
        * With embeddings: one point per k-means cluster centre, projected
          to 2-D with PCA and rescaled to [-1, 1); each centre is labelled
          with the snippet whose embedding is closest to it.
        * Empty input yields ``[]``; a single embedded snippet yields one
          point at the origin.
    """
    if not snips:
        return []
    if embeds is None:
        return [(*deterministic_xy(s), s[:40]) for s in snips[:16]]

    embeds = np.asarray(embeds)
    n_samples = len(embeds)
    if n_samples == 0:
        return []
    if n_samples == 1:
        # Neither k-means nor a 2-component PCA is meaningful for one sample.
        return [(0.0, 0.0, snips[0][:40])]

    # Roughly one cluster per five snippets, kept between 4 and 12, but never
    # more clusters than samples (KMeans rejects n_clusters > n_samples).
    k = min(max(len(snips) // 5, 4), 12, n_samples)
    km = KMeans(n_clusters=k, n_init="auto", random_state=0).fit(embeds)
    p2 = PCA(2, random_state=0).fit_transform(km.cluster_centers_)
    labels = [
        snips[int(np.argmin(np.linalg.norm(embeds - c, axis=1)))][:40]
        for c in km.cluster_centers_
    ]
    xs = _rescale(p2[:, 0])
    ys = _rescale(p2[:, 1])
    return list(zip(xs, ys, labels))
def _plain(text):
    """Escape ``$`` so Matplotlib renders *text* literally, not as mathtext."""
    return str(text).replace("$", r"\$")


def draw(points, ax1, ax2):
    """Render the scenario map and return it as a PIL image.

    Args:
        points: Iterable of ``(x, y, label)`` tuples.
        ax1: Text for the x-axis label.
        ax2: Text for the y-axis label.

    Returns:
        A PIL image opened from an in-memory PNG of a 5x5 inch figure with
        both axes limited to [-1.1, 1.1] and lines drawn through the origin.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.axhline(0)
        ax.axvline(0)
        ax.set_xlim(-1.1, 1.1)
        ax.set_ylim(-1.1, 1.1)
        # Axis names and labels are free text (user input / web snippets); a
        # pair of "$" would otherwise be interpreted as a math expression and
        # can make rendering fail.
        ax.set_xlabel(_plain(ax1))
        ax.set_ylabel(_plain(ax2))
        for x, y, lbl in points:
            ax.scatter(x, y)
            ax.text(x, y, _plain(lbl), fontsize=8)
        buf = io.BytesIO()
        fig.tight_layout()
        fig.savefig(buf, format="png")
    finally:
        # pyplot keeps every figure registered until it is closed; close it
        # even when plotting or saving raised so failed calls do not leak.
        plt.close(fig)
    buf.seek(0)
    return PILImage.open(buf)
# ── MCP tool 1 ──────────────────────────────────────────────────────────────
def scenario_lab(axis1: str, axis2: str, openai_key: str | None = None):
    """Make a scenario map once and return {session_id, scenarios}.

    Args:
        axis1: Name of the first axis; combined with axis2 into the web
            search query.
        axis2: Name of the second axis.
        openai_key: OpenAI API key for embedding-based clustering. Pass
            "*****" or leave blank to use the HF Space secret
            OPENAI_API_KEY.

    Returns:
        A dict with "session_id" (pass it to get_plot to render the map) and
        "scenarios" (the list of point labels).

    Raises:
        gr.Error: If either axis is missing or blank.
    """
    # Secret handling --------------------------------------------------------
    # Strip first so a key pasted with surrounding whitespace still works and
    # a whitespace-only value counts as "blank".
    openai_key = (openai_key or "").strip()
    if not openai_key or openai_key == "*****":
        openai_key = os.getenv("OPENAI_API_KEY", "")

    # Basic validation -------------------------------------------------------
    # `or ""` turns a missing (None) axis into the validation error below
    # instead of an AttributeError.
    axis1, axis2 = (axis1 or "").strip(), (axis2 or "").strip()
    if not axis1 or not axis2:
        raise gr.Error("Both axes are required.")

    # Build map --------------------------------------------------------------
    snippets = web_search(f"{axis1} {axis2}", 20)
    # Nothing to embed when the search came back empty; skip the API call.
    embeds = embed(snippets, openai_key) if snippets else None
    points = cluster(snippets, embeds)

    sid = str(uuid.uuid4())
    LABS[sid] = dict(axis1=axis1, axis2=axis2, points=points)
    labels = [p[2] for p in points]
    return {"session_id": sid, "scenarios": labels}
# ── MCP tool 2 ──────────────────────────────────────────────────────────────
def get_plot(session_id: str):
    """Return the plot image for a map previously created by scenario_lab.

    Raises gr.Error("Unknown session_id") when *session_id* is not in LABS.
    """
    lab = LABS.get(session_id)
    if lab is None:
        raise gr.Error("Unknown session_id")
    axis1, axis2 = lab["axis1"], lab["axis2"]
    return draw(lab["points"], axis1, axis2)
# ── Minimal UI so you can try it in a browser ───────────────────────────────
# First tab: the form that calls scenario_lab and shows its JSON result.
create_ui = gr.Interface(
    fn=scenario_lab,
    inputs=[
        gr.Textbox(label="Axis 1"),
        gr.Textbox(label="Axis 2"),
        # The label previously contained a mis-decoded opening quote.
        gr.Textbox(label="OpenAI key (opt. or ‘*****’)", type="password"),
    ],
    outputs="json",
    title="Create Scenarios",
    api_name="scenario_lab",  # exposes as MCP tool
)
# Second tab: takes a session id and displays the image from get_plot.
plot_ui = gr.Interface(
    title="Get Plot",
    fn=get_plot,
    inputs=gr.Textbox(label="session_id"),
    outputs="image",
    api_name="get_plot",  # exposes as MCP tool
)
# Combine both interfaces into one app with a "Create" and a "Plot" tab.
demo = gr.TabbedInterface(
    [create_ui, plot_ui],
    ["Create", "Plot"],
)

if __name__ == "__main__":
    # UI + MCP server
    demo.launch(mcp_server=True)