File size: 5,133 Bytes
a83cf51
 
a9934ee
de3813f
deec983
de3813f
 
 
 
 
 
 
9cf2de6
de3813f
a83cf51
de3813f
 
2b0455a
de3813f
a83cf51
de3813f
a83cf51
2b0455a
deec983
a83cf51
2b0455a
d20cf4b
2b0455a
a9934ee
a83cf51
 
a9934ee
 
2b0455a
a9934ee
a83cf51
de3813f
a83cf51
 
de3813f
2b0455a
de3813f
a83cf51
5cdd68c
d20cf4b
a83cf51
2b0455a
a9934ee
 
 
a83cf51
5cdd68c
 
2b0455a
de3813f
a9934ee
de3813f
2b0455a
a9934ee
 
2b0455a
 
a83cf51
 
 
deec983
9cf2de6
2b0455a
 
 
a83cf51
 
2b0455a
a83cf51
 
 
 
 
d47234c
a9934ee
 
a83cf51
 
 
 
 
 
 
 
2b0455a
a83cf51
a9934ee
de3813f
2b0455a
 
 
 
 
a9934ee
 
a83cf51
2b0455a
 
 
 
a83cf51
 
2b0455a
a83cf51
2b0455a
 
 
 
 
a83cf51
2b0455a
a83cf51
2b0455a
 
 
 
de3813f
a83cf51
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# app.py  –  simple one-shot scenario mapper (MCP-ready)

import io, os, uuid, hashlib, json, warnings
from datetime import datetime
from pathlib import Path

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
from duckduckgo_search import DDGS
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from PIL import Image as PILImage

# ── optional fancy clustering via OpenAI embeddings ────────────────────────
# When `openai` cannot be imported, `embed` returns None and `cluster` then
# places points by hashing instead of by embedding.
try:
    import openai
except ImportError:
    warnings.warn("`openai` package not found; falling back to hash clustering.")
    openai = None

# ── in-memory store (one run = one entry) ───────────────────────────────────
# Maps a session id (a uuid4 string) to that run's axes and plotted points.
LABS: dict[str, dict] = dict()

# ── helpers ────────────────────────────────────────────────────────────────
def web_search(q: str, k: int = 20):
    """Return up to ``k`` DuckDuckGo text results for ``q`` as plain strings.

    Each result is rendered as ``"<title> – <body>"``. A missing title or body
    becomes an empty string instead of raising ``KeyError``.

    Args:
        q: Search query.
        k: Maximum number of results to request.

    Returns:
        A list of snippet strings (possibly empty).
    """
    with DDGS() as ddgs:
        # Defensive: treat a falsy/None return as "no results" (whether a given
        # duckduckgo_search version can return None is unverified — TODO confirm).
        results = ddgs.text(q, max_results=k) or []
        # Title and body are both read with .get() so a partial result record
        # doesn't abort the whole search.
        return [f"{r.get('title', '')} – {r.get('body', '')}" for r in results]

def deterministic_xy(txt: str):
    """Map ``txt`` to a reproducible pseudo-random point in [-1, 1)².

    The SHA-256 digest of the encoded text is read as one big integer. Two
    values are derived from it: the integer mod 1000, and the integer shifted
    right by 12 bits mod 1000. Each is scaled from 0..999 onto [-1, 1) to give
    x and y respectively.
    """
    digest = int.from_bytes(hashlib.sha256(txt.encode()).digest(), "big")
    x = (digest % 1000) / 500 - 1
    y = ((digest >> 12) % 1000) / 500 - 1
    return (x, y)

def embed(texts: list[str], key: str | None):
    """Embed ``texts`` with OpenAI, or return None to request the hash layout.

    Args:
        texts: Strings to embed; one vector is produced per string.
        key: OpenAI API key. A falsy value means "no key".

    Returns:
        A numpy array with one row per input text (row ``i`` belongs to
        ``texts[i]``). Returns ``None`` instead when the ``openai`` package is
        missing, no key was supplied, or ``texts`` is empty. Callers then fall
        back to deterministic hashing.
    """
    if openai is None or not key or not texts:
        return None                          # Fallback: hashing
    # A per-call client keeps the key local to this request instead of
    # overwriting the process-wide ``openai.api_key``, which concurrent
    # requests carrying different keys would otherwise race on.
    client = openai.OpenAI(api_key=key)
    resp = client.embeddings.create(model="text-embedding-3-small", input=texts)
    # Order by the returned ``index`` so rows line up with ``texts``.
    data = sorted(resp.data, key=lambda d: d.index)
    return np.array([d.embedding for d in data])

def cluster(snips: list[str], embeds):
    """Turn snippets into labelled 2-D points for the scenario map.

    Without embeddings (``embeds is None``), the first 16 snippets are placed
    by ``deterministic_xy`` and labelled with their first 40 characters.

    With embeddings, the path is:
    - group the embeddings with KMeans (about one cluster per 5 snippets,
      clamped to [4, 12] and to the number of samples);
    - project the cluster centres to 2-D with PCA;
    - rescale each axis to roughly [-1, 1];
    - label each centre with the first 40 characters of its nearest snippet.

    Args:
        snips: Snippet strings.
        embeds: ``None``, or an array-like with one row per snippet
            (row ``i`` is the embedding of ``snips[i]``).

    Returns:
        A list of ``(x, y, label)`` tuples; empty when there is nothing to plot.
    """
    if embeds is None:
        return [(*deterministic_xy(s), s[:40]) for s in snips[:16]]

    embeds = np.asarray(embeds, dtype=float)
    n = len(embeds)
    if n == 0:
        return []
    if n == 1:
        # KMeans + PCA(2) need at least two samples; a lone scenario sits at
        # the centre of the map.
        return [(0.0, 0.0, snips[0][:40])]

    # ~1 cluster per 5 snippets within [4, 12], but never more clusters than
    # samples (KMeans raises ValueError otherwise).
    k = min(max(n // 5, 4), 12, n)
    km = KMeans(n_clusters=k, n_init="auto", random_state=0).fit(embeds)
    centers = km.cluster_centers_

    # PCA cannot return more components than min(samples, features); pad with
    # zero columns so there are always two coordinates.
    n_comp = min(2, *centers.shape)
    p2 = PCA(n_comp, random_state=0).fit_transform(centers)
    if n_comp < 2:
        p2 = np.hstack([p2, np.zeros((len(p2), 2 - n_comp))])

    labels = [
        snips[int(np.argmin(np.linalg.norm(embeds - c, axis=1)))][:40]
        for c in centers
    ]
    xs, ys = p2[:, 0], p2[:, 1]
    # Min-max rescale onto [-1, 1); the epsilon avoids dividing by zero when
    # all centres share a coordinate.
    xs = (xs - xs.min()) / (np.ptp(xs) + 1e-4) * 2 - 1
    ys = (ys - ys.min()) / (np.ptp(ys) + 1e-4) * 2 - 1
    return list(zip(xs, ys, labels))

def draw(points, ax1, ax2):
    """Render labelled points on a quadrant chart and return it as a PIL image.

    Args:
        points: Iterable of ``(x, y, label)`` tuples. The view is fixed to
            [-1.1, 1.1] on both axes, so points are expected to lie near
            [-1, 1].
        ax1: Label for the horizontal axis.
        ax2: Label for the vertical axis.

    Returns:
        A PIL image containing the PNG rendering.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps a
    # global, non-thread-safe figure registry. It would also retain the figure
    # if rendering raised before ``plt.close`` ran.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(5, 5))
    ax = fig.subplots()
    ax.axhline(0)
    ax.axvline(0)
    ax.set_xlim(-1.1, 1.1)
    ax.set_ylim(-1.1, 1.1)
    ax.set_xlabel(ax1)
    ax.set_ylabel(ax2)
    for x, y, lbl in points:
        ax.scatter(x, y)
        ax.text(x, y, lbl, fontsize=8)

    buf = io.BytesIO()
    fig.tight_layout()
    fig.savefig(buf, format="png")
    buf.seek(0)
    img = PILImage.open(buf)
    img.load()  # decode now rather than lazily reading from ``buf`` later
    return img

# ── MCP tool 1 ──────────────────────────────────────────────────────────────
def scenario_lab(axis1: str, axis2: str, openai_key: str | None = None):
    """
    Make a scenario map once and return {session_id, scenarios}.

    Searches the web for the two axes and clusters the result snippets. It uses
    OpenAI embeddings when a key is available and hashing otherwise. The
    resulting points are stored in memory under a new session id, which
    `get_plot` can render.

    Args:
        axis1: Label for the x-axis of the plot; must not be blank.
        axis2: Label for the y-axis of the plot; must not be blank.
        openai_key: OpenAI API key. Pass "*****" or leave blank to use the
            HF Space secret OPENAI_API_KEY.

    Returns:
        A dict with "session_id" (str) and "scenarios" (list of labels).
    """
    # Secret handling --------------------------------------------------------
    # Strip so a key pasted with stray whitespace or a trailing newline still
    # authenticates.
    openai_key = (openai_key or "").strip()
    if not openai_key or openai_key == "*****":
        openai_key = os.getenv("OPENAI_API_KEY", "").strip()

    # Basic validation -------------------------------------------------------
    # `or ""` turns a missing (None) axis into a validation error rather than
    # an AttributeError.
    axis1, axis2 = (axis1 or "").strip(), (axis2 or "").strip()
    if not axis1 or not axis2:
        raise gr.Error("Both axes are required.")

    # Build map --------------------------------------------------------------
    snippets = web_search(f"{axis1} {axis2}", 20)
    if not snippets:
        # Without snippets, the embedding call would fail or the map would
        # come out empty.
        raise gr.Error("Web search returned no results; try different axes.")
    embeds = embed(snippets, openai_key)
    points = cluster(snippets, embeds)

    sid = str(uuid.uuid4())
    LABS[sid] = dict(axis1=axis1, axis2=axis2, points=points)
    labels = [p[2] for p in points]

    return {"session_id": sid, "scenarios": labels}

# ── MCP tool 2 ──────────────────────────────────────────────────────────────
def get_plot(session_id: str):
    """
    Render the scenario map stored for a session as an image.

    Args:
        session_id: The "session_id" value returned by `scenario_lab`.

    Returns:
        A PIL image of the labelled scenario points on the two named axes.

    Raises:
        gr.Error: if the session id is unknown. Runs are kept only in an
            in-memory dict, so they are lost when the process restarts.
    """
    # Strip surrounding whitespace so a copy-pasted id still matches; `or ""`
    # turns a missing id into the normal "unknown" error.
    d = LABS.get((session_id or "").strip())
    if d is None:
        raise gr.Error("Unknown session_id")
    return draw(d["points"], d["axis1"], d["axis2"])

# ── Minimal UI so you can try it in a browser ───────────────────────────────
# Each Interface is exposed under its `api_name`, which is also the tool name
# when the app is launched with `mcp_server=True`.
create_ui = gr.Interface(
    fn=scenario_lab,
    inputs=[gr.Textbox(label="Axis 1"),
            gr.Textbox(label="Axis 2"),
            # Plain ASCII quotes: the previous curly quotes had been
            # mis-encoded into "β€˜" in the label text.
            gr.Textbox(label="OpenAI key (opt. or '*****')", type="password")],
    outputs="json",
    title="Create Scenarios",
    api_name="scenario_lab"          # exposes as MCP tool
)

plot_ui = gr.Interface(
    fn=get_plot,
    inputs=gr.Textbox(label="session_id"),
    outputs="image",
    title="Get Plot",
    api_name="get_plot"              # exposes as MCP tool
)

demo = gr.TabbedInterface([create_ui, plot_ui], ["Create", "Plot"])

if __name__ == "__main__":
    demo.launch(mcp_server=True)     # UI + MCP server