File size: 3,414 Bytes
f822bfd
 
 
 
 
 
 
e0589c8
f822bfd
 
 
4545835
 
f822bfd
 
4545835
 
 
f822bfd
 
 
4545835
 
 
 
f822bfd
 
 
4545835
 
 
f822bfd
 
 
 
4545835
 
 
f822bfd
 
 
 
 
 
e0589c8
4545835
f822bfd
 
 
 
4545835
f822bfd
4545835
 
 
 
f822bfd
e0589c8
f822bfd
 
e0589c8
f822bfd
4545835
 
f822bfd
 
 
 
 
 
4545835
f822bfd
4545835
 
 
f822bfd
 
4545835
 
 
 
f822bfd
 
4545835
f822bfd
4545835
f822bfd
 
4545835
 
 
f822bfd
 
4545835
f822bfd
 
 
 
4545835
f822bfd
 
 
 
 
 
4545835
f822bfd
 
4545835
f822bfd
 
4545835
f822bfd
 
 
 
 
 
 
e0589c8
f822bfd
4545835
 
 
f822bfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4545835
f822bfd
4545835
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "ell-ai==0.0.14",
#     "marimo",
#     "openai==1.53.0",
#     "polars==1.12.0",
#     "altair==5.4.1",
# ]
# ///

import marimo

# Version string recorded in this notebook file (presumably the marimo
# release that generated it).
__generated_with = "0.9.20"
# The notebook application object; each cell below registers itself on it
# through the `@app.cell` decorator.
app = marimo.App(width="medium")


@app.cell
def __(mo):
    # Notebook title, written as a level-one markdown heading.
    mo.md(r"""# Generative UI Chatbot""")
    return


@app.cell
def __(mo):
    # Full-width text box holding the dataset location. It starts out pointing
    # at the Fish CSV under the `hf://datasets/scikit-learn` path.
    dataset_input = mo.ui.text(
        value="hf://datasets/scikit-learn/Fish/Fish.csv",
        full_width=True,
    )
    return (dataset_input,)


@app.cell
def __(dataset_input, mo):
    # Intro sentence; the `dataset_input` element is interpolated directly
    # into the markdown so the dataset location is part of the text.
    mo.md(f"""
    This chatbot can answer questions about the following dataset: {dataset_input}
    """)
    return


@app.cell
def __(dataset_input, mo, pl):
    # Read the CSV named in the text box and report the outcome as this
    # cell's output. Any failure is shown as a "danger" callout, and `df`
    # falls back to an empty frame so it is always defined.
    try:
        df = pl.read_csv(dataset_input.value)
        _n_rows, _n_cols = df.shape
        _summary = mo.md(
            f"Loaded dataset with {_n_rows} rows and {_n_cols} columns."
        )
        mo.output.replace(_summary)
    except Exception as e:
        df = pl.DataFrame()
        _failure = mo.md(f"""**Error loading dataset**:\n\n{e}""")
        mo.output.replace(_failure.callout(kind="danger"))
    return (df,)


@app.cell
def __():
    # Shared imports. Returning them is what lets the other cells receive
    # `mo`, `os` and `pl` as parameters.
    import os

    import marimo as mo
    import polars as pl

    return mo, os, pl


@app.cell
def __(mo, os):
    # Password-style field for the OpenAI key, pre-filled from the
    # OPENAI_API_KEY environment variable and left empty when that variable
    # is unset or blank.
    _env_key = os.environ.get("OPENAI_API_KEY", "")
    api_key_input = mo.ui.text(
        value=_env_key,
        kind="password",
        label="OpenAI API Key",
    )
    return (api_key_input,)


@app.cell
def __(api_key_input):
    # Bare expression: marimo displays a cell's last expression, so this is
    # what puts the API-key field on screen.
    api_key_input
    return


@app.cell
def __(api_key_input, mo):
    from openai import Client

    # Without a key there is nothing to build: halt this cell and show a
    # short notice instead of constructing a client.
    _key = api_key_input.value
    _notice = mo.md("_Missing API key_")
    mo.stop(not _key, _notice)

    client = Client(api_key=_key)
    return Client, client


@app.cell
def __(df, mo):
    import ell

    # Both functions below are registered as `ell` tools. Their docstrings
    # are left exactly as written because they are runtime text
    # (presumably the tool descriptions shown to the model; confirm in ell docs).

    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart"""
        import altair as alt

        # Circle marks over the loaded frame, with the three encodings taken
        # verbatim from the tool arguments; chart width is fixed at 500.
        circles = alt.Chart(df).mark_circle()
        encoded = circles.encode(x=x_encoding, y=y_encoding, color=color)
        return encoded.properties(width=500)

    @ell.tool()
    def filter_dataset(sql_query: str):
        """
        Filter a polars dataframe using SQL. Please only use fields from the schema.
        When referring to the table in SQL, call it 'data'.
        """
        # Run the query against `df`, exposed to SQL under the name 'data'.
        result = df.sql(sql_query, table_name="data")
        # Show the executed query above the table as a fenced SQL snippet.
        query_label = f"```sql\n{sql_query}\n```"
        return mo.ui.table(
            result,
            label=query_label,
            selection=None,
            show_column_summaries=False,
        )

    return chart_data, ell, filter_dataset


@app.cell
def __(chart_data, client, df, ell, filter_dataset, mo):
    # The docstring and the returned f-string of `analyze_dataset` are
    # runtime text (presumably the system and user prompts; confirm in ell docs),
    # so both are kept exactly as they were.
    @ell.complex(client=client, tools=[chart_data, filter_dataset], model="gpt-4o")
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def my_model(messages):
        # NOTE(review): `messages` is passed through unchanged as `prompt`,
        # so its string form is what lands in the f-string above -- confirm
        # that this is intended.
        reply = analyze_dataset(messages)
        requested_calls = reply.tool_calls
        if not requested_calls:
            # Plain answer: hand the text back to the chat widget.
            return reply.text
        # Only the first requested tool call is executed; its result (a chart
        # or a table) becomes the chat response.
        return requested_calls[0]()

    _starter_prompts = [
        "Can you chart two columns of your choosing?",
        "Can you find the min, max of all numeric fields?",
        "What is the sum of {{column}}?",
    ]
    # Last expression of the cell: the chat widget driven by `my_model`.
    mo.ui.chat(my_model, prompts=_starter_prompts)
    return analyze_dataset, my_model


# Run the notebook app when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()