hra commited on
Commit
d20cf4b
Β·
verified Β·
1 Parent(s): 9683749

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -78
app.py CHANGED
@@ -1,6 +1,5 @@
1
  """
2
- 🌐 What-If Lab β€” live 2Γ—2 scenario mapper (MCP-enabled)
3
- Harsha Angeri / ChatGPT helper β€’ MIT
4
  """
5
 
6
  import io, os, time, uuid, threading, hashlib, warnings
@@ -18,143 +17,138 @@ try:
18
  import openai
19
  except ImportError:
20
  openai = None
21
- warnings.warn("`openai` package not found; embeddings disabled")
22
 
23
- # ───────────────────────── helpers ──────────────────────────
24
- def utc() -> str:
25
  return datetime.utcnow().strftime("%H:%M:%S")
26
 
27
- def log(txt: str):
28
- print(f"[{utc()}] {txt}", flush=True)
29
 
30
- def web_search(q: str, k: int = 20) -> list[str]:
31
- with DDGS() as dd:
32
- return [f"{r['title']} – {r.get('body','')}" for r in dd.text(q, max_results=k)]
33
 
34
- def hash_xy(txt: str):
35
- h = int(hashlib.sha256(txt.encode()).hexdigest(), 16)
36
  return ((h % 1000) / 500 - 1, ((h >> 10) % 1000) / 500 - 1)
37
 
38
- def embed(txts: list[str], key: str):
39
  if openai is None:
40
- raise RuntimeError("install `openai` for embeddings")
41
  if not key:
42
  raise RuntimeError("OpenAI key missing")
43
  openai.api_key = key
44
- res = openai.embeddings.create(model="text-embedding-3-small", input=txts)
45
  return np.array([d.embedding for d in res.data])
46
 
47
- def cluster(snips, embeds):
48
  if embeds is None:
49
- return [(*hash_xy(s), s[:50] + ("…" if len(s) > 50 else "")) for s in snips[:16]]
50
 
51
  k = min(max(len(snips) // 5, 4), 12)
52
- km = KMeans(n_clusters=k, n_init="auto", random_state=0).fit(embeds)
53
- pca2 = PCA(n_components=2, random_state=0).fit_transform(km.cluster_centers_)
54
-
55
- labels = [snips[int(np.argmin(np.linalg.norm(embeds - c, axis=1)))][:50] + "…"
56
- for c in km.cluster_centers_]
57
-
58
- xs, ys = pca2[:, 0], pca2[:, 1]
59
  xs = (xs - xs.min()) / (np.ptp(xs) + 1e-4) * 2 - 1
60
  ys = (ys - ys.min()) / (np.ptp(ys) + 1e-4) * 2 - 1
61
  return list(zip(xs, ys, labels))
62
 
63
- def make_plot(points, ax1, ax2):
64
  fig, ax = plt.subplots(figsize=(5, 5))
65
- ax.axhline(0, color="black", lw=0.8); ax.axvline(0, color="black", lw=0.8)
 
66
  ax.set_xlim(-1.1, 1.1); ax.set_ylim(-1.1, 1.1)
67
  ax.set_xlabel(ax1, fontsize=12, fontweight="bold")
68
  ax.set_ylabel(ax2, fontsize=12, fontweight="bold")
69
- for x, y, label in points:
70
- ax.scatter(x, y, s=40); ax.text(x, y, label, fontsize=8, ha="left", va="bottom")
71
  buf = io.BytesIO(); fig.tight_layout(); fig.savefig(buf, format="png"); plt.close(fig)
72
  buf.seek(0)
73
- return PILImage.open(buf)
74
 
75
- # ─────────────────── session store & worker ──────────────────
76
  LABS = {}
77
 
78
  def worker(sid):
79
- cfg = LABS[sid]
80
- a1, a2, refresh, key = cfg["axis1"], cfg["axis2"], cfg["refresh"], cfg["key"]
81
- log(f"Worker started {sid}")
82
  while cfg["running"]:
83
  t0 = time.perf_counter()
84
  try:
85
- snips = web_search(f"{a1} {a2}", 20)
86
- cfg["logs"].append(f"[{utc()}] {len(snips)} signals")
87
  except Exception as e:
88
- cfg["logs"].append(f"[{utc()}] Search ERROR: {e}"); time.sleep(refresh); continue
89
  try:
90
- embeds = embed(snips, key) if key else None
91
  except Exception as e:
92
- cfg["logs"].append(f"[{utc()}] Embed fallback: {e}"); embeds = None
93
  try:
94
- cfg["points"] = cluster(snips, embeds)
95
- cfg["logs"].append(f"[{utc()}] Plot with {len(cfg['points'])} pts")
96
  except Exception as e:
97
- cfg["logs"].append(f"[{utc()}] Cluster ERROR: {e}")
98
- time.sleep(max(refresh - (time.perf_counter() - t0), 0.3))
99
- log(f"Worker stopped {sid}")
100
 
101
- # ─────────────────────── MCP endpoints ───────────────────────
102
  def scenario_lab(axis1, axis2, refresh=20, openai_key=None):
103
  axis1, axis2 = axis1.strip(), axis2.strip()
104
  if len(axis1) < 3 or len(axis2) < 3:
105
- raise gr.Error("Axis prompts must be β‰₯3 chars")
106
  sid = str(uuid.uuid4())
107
- LABS[sid] = dict(
108
- axis1=axis1.title(), axis2=axis2.title(),
109
- refresh=max(5, int(refresh)), key=openai_key or os.getenv("OPENAI_API_KEY", ""),
110
- points=[], logs=[f"[{utc()}] Session created"], running=True
111
- )
112
  threading.Thread(target=worker, args=(sid,), daemon=True).start()
113
- log(f"Created {sid}")
114
  return {"session_id": sid}
115
 
116
  def get_plot(session_id):
117
  cfg = LABS.get(session_id) or gr.Error("Unknown session_id")
118
- return make_plot(cfg["points"], cfg["axis1"], cfg["axis2"]) if cfg["points"] else \
119
- make_plot([], cfg["axis1"], cfg["axis2"])
 
120
 
121
  def get_logs(session_id):
122
  cfg = LABS.get(session_id) or gr.Error("Unknown session_id")
123
  return cfg["logs"][-40:]
124
 
125
- # ─────────────────────────── UI ──────────────────────────────
126
  with gr.Blocks(title="🌐 What-If Lab") as demo:
127
- gr.Markdown("## 🌐 What-If Lab β€” live 2Γ—2 scenario mapper")
128
  with gr.Row():
129
- in_a1 = gr.Textbox(label="Axis A"); in_a2 = gr.Textbox(label="Axis B")
130
  with gr.Row():
131
- in_ref = gr.Slider(5, 60, value=20, step=1, label="Refresh (s)")
132
- in_key = gr.Textbox(label="OpenAI key (opt.)", type="password")
133
- btn = gr.Button("πŸš€ Run")
134
- out_img = gr.Image(label="Live 2Γ—2", height=400)
135
- out_log = gr.Textbox(label="Log", lines=20, interactive=False)
136
- sid = gr.State("")
137
-
138
- # launch session
139
- btn.click(lambda a,b,r,k: scenario_lab(a,b,r,k)["session_id"],
140
- inputs=[in_a1, in_a2, in_ref, in_key],
141
- outputs=sid)
142
-
143
- # polling callback
144
- def ui_poll(s):
145
- if not s: return gr.update(), gr.update()
146
  try:
147
- return gr.update(value=get_plot(s)), gr.update(value="\n".join(get_logs(s)))
148
  except Exception as e:
149
  return gr.update(), gr.update(value=f"Error: {e}")
150
 
151
- # first load
152
- demo.load(ui_poll, inputs=[sid], outputs=[out_img, out_log])
153
 
154
- # 3-second timer (new API)
155
- timer = gr.Timer(3)
156
- timer.then(ui_poll, inputs=[sid], outputs=[out_img, out_log])
 
 
 
 
157
 
158
- # ─────────────────────���─── launch ────────────────────────────
159
  if __name__ == "__main__":
160
  demo.launch(show_api=True, mcp_server=True, share=True)
 
1
  """
2
+ What-If Lab – live 2Γ—2 scenario mapper (MCP-enabled, fully patched for Spaces)
 
3
  """
4
 
5
  import io, os, time, uuid, threading, hashlib, warnings
 
17
  import openai
18
  except ImportError:
19
  openai = None
20
+ warnings.warn("`openai` package not found; OpenAI clustering disabled")
21
 
22
+ # Helpers
23
+ def utc_ts() -> str:
24
  return datetime.utcnow().strftime("%H:%M:%S")
25
 
26
+ def log(msg: str):
27
+ print(f"[{utc_ts()}] {msg}", flush=True)
28
 
29
+ def search_web(query: str, k: int = 20) -> list[str]:
30
+ with DDGS() as ddgs:
31
+ return [f"{r['title']} – {r.get('body','')}" for r in ddgs.text(query, max_results=k)]
32
 
33
+ def deterministic_xy(text: str):
34
+ h = int(hashlib.sha256(text.encode()).hexdigest(), 16)
35
  return ((h % 1000) / 500 - 1, ((h >> 10) % 1000) / 500 - 1)
36
 
37
+ def embed_texts(texts: list[str], key: str):
38
  if openai is None:
39
+ raise RuntimeError("`openai` not installed")
40
  if not key:
41
  raise RuntimeError("OpenAI key missing")
42
  openai.api_key = key
43
+ res = openai.embeddings.create(model="text-embedding-3-small", input=texts)
44
  return np.array([d.embedding for d in res.data])
45
 
46
+ def cluster_points(snips: list[str], embeds):
47
  if embeds is None:
48
+ return [(*deterministic_xy(s), s[:50] + ("…" if len(s) > 50 else "")) for s in snips[:16]]
49
 
50
  k = min(max(len(snips) // 5, 4), 12)
51
+ km = KMeans(n_clusters=k, n_init="auto", random_state=0).fit(embeds)
52
+ pca2d = PCA(n_components=2, random_state=0).fit_transform(km.cluster_centers_)
53
+ labels = []
54
+ for c in km.cluster_centers_:
55
+ labels.append(snips[int(np.argmin(np.linalg.norm(embeds - c, axis=1)))][:50] + "…")
56
+ xs, ys = pca2d[:, 0], pca2d[:, 1]
 
57
  xs = (xs - xs.min()) / (np.ptp(xs) + 1e-4) * 2 - 1
58
  ys = (ys - ys.min()) / (np.ptp(ys) + 1e-4) * 2 - 1
59
  return list(zip(xs, ys, labels))
60
 
61
+ def draw_plot(points, ax1, ax2):
62
  fig, ax = plt.subplots(figsize=(5, 5))
63
+ ax.axhline(0, color="black", lw=0.8)
64
+ ax.axvline(0, color="black", lw=0.8)
65
  ax.set_xlim(-1.1, 1.1); ax.set_ylim(-1.1, 1.1)
66
  ax.set_xlabel(ax1, fontsize=12, fontweight="bold")
67
  ax.set_ylabel(ax2, fontsize=12, fontweight="bold")
68
+ for x, y, lbl in points:
69
+ ax.scatter(x, y, s=40); ax.text(x, y, lbl, fontsize=8, ha="left", va="bottom")
70
  buf = io.BytesIO(); fig.tight_layout(); fig.savefig(buf, format="png"); plt.close(fig)
71
  buf.seek(0)
72
+ return PILImage.open(buf) # <-- return PIL.Image object
73
 
74
+ # Session store & worker
75
  LABS = {}
76
 
77
  def worker(sid):
78
+ cfg = LABS[sid]; a1, a2, ref, key = cfg["axis1"], cfg["axis2"], cfg["refresh"], cfg["key"]
79
+ log(f"Worker started for {sid}")
 
80
  while cfg["running"]:
81
  t0 = time.perf_counter()
82
  try:
83
+ snips = search_web(f"{a1} {a2}", 20); cfg["logs"].append(f"[{utc_ts()}] {len(snips)} signals")
 
84
  except Exception as e:
85
+ cfg["logs"].append(f"[{utc_ts()}] Search ERROR: {e}"); time.sleep(ref); continue
86
  try:
87
+ embeds = embed_texts(snips, key) if key else None
88
  except Exception as e:
89
+ cfg["logs"].append(f"[{utc_ts()}] Embed fallback: {e}"); embeds = None
90
  try:
91
+ cfg["points"] = cluster_points(snips, embeds)
92
+ cfg["logs"].append(f"[{utc_ts()}] Plot with {len(cfg['points'])} pts")
93
  except Exception as e:
94
+ cfg["logs"].append(f"[{utc_ts()}] Cluster ERROR: {e}")
95
+ time.sleep(max(ref - (time.perf_counter() - t0), 0.2))
96
+ log(f"Worker stopped for {sid}")
97
 
98
+ # MCP-exposed endpoints
99
  def scenario_lab(axis1, axis2, refresh=20, openai_key=None):
100
  axis1, axis2 = axis1.strip(), axis2.strip()
101
  if len(axis1) < 3 or len(axis2) < 3:
102
+ raise gr.Error("Axis prompts must be β‰₯3 chars.")
103
  sid = str(uuid.uuid4())
104
+ LABS[sid] = dict(axis1=axis1.title(), axis2=axis2.title(),
105
+ refresh=max(5, int(refresh)), key=openai_key or os.getenv("OPENAI_API_KEY",""),
106
+ points=[], logs=[f"[{utc_ts()}] Session created"], running=True)
 
 
107
  threading.Thread(target=worker, args=(sid,), daemon=True).start()
108
+ log(f"Created session {sid}")
109
  return {"session_id": sid}
110
 
111
  def get_plot(session_id):
112
  cfg = LABS.get(session_id) or gr.Error("Unknown session_id")
113
+ if not cfg["points"]:
114
+ return draw_plot([], cfg["axis1"], cfg["axis2"])
115
+ return draw_plot(cfg["points"], cfg["axis1"], cfg["axis2"])
116
 
117
  def get_logs(session_id):
118
  cfg = LABS.get(session_id) or gr.Error("Unknown session_id")
119
  return cfg["logs"][-40:]
120
 
121
+ # Gradio UI
122
  with gr.Blocks(title="🌐 What-If Lab") as demo:
123
+ gr.Markdown("## 🌐 What-If Lab – live scenario 2Γ—2")
124
  with gr.Row():
125
+ a1 = gr.Textbox(label="Axis A"); a2 = gr.Textbox(label="Axis B")
126
  with gr.Row():
127
+ ref = gr.Slider(5, 60, value=20, step=1, label="Refresh (s)")
128
+ key = gr.Textbox(label="OpenAI key (opt.)", type="password"); run = gr.Button("πŸš€ Run")
129
+ img = gr.Image(label="Live 2Γ—2", height=400); log_box = gr.Textbox(label="Log", lines=20, interactive=False)
130
+ sid_state = gr.State("")
131
+
132
+ def ui_launch(x, y, r, k): return scenario_lab(x, y, r, k)["session_id"]
133
+ run.click(ui_launch, [a1, a2, ref, key], sid_state)
134
+
135
+ def ui_poll(sid):
136
+ if not sid: return gr.update(), gr.update()
 
 
 
 
 
137
  try:
138
+ return gr.update(value=get_plot(sid)), gr.update(value="\n".join(get_logs(sid)))
139
  except Exception as e:
140
  return gr.update(), gr.update(value=f"Error: {e}")
141
 
142
+ demo.load(ui_poll, [sid_state], [img, log_box])
 
143
 
144
+ # βœ… FIXED: new Timer syntax for Gradio 4.21+
145
+ gr.Timer(
146
+ fn=ui_poll,
147
+ every=30,
148
+ inputs=[sid_state],
149
+ outputs=[img, log_box]
150
+ )
151
 
152
+ # Launch (UI + MCP)
153
  if __name__ == "__main__":
154
  demo.launch(show_api=True, mcp_server=True, share=True)