|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import marimo |
|
|
|
# Version string recorded by marimo's code generator for this notebook file.
__generated_with = "0.9.20"

# The notebook application; each `@app.cell`-decorated function below registers one cell.
app = marimo.App(width="medium")
|
|
|
|
|
@app.cell
def __(mo):
    # Notebook heading. Being the cell's last expression, it is what the cell renders.
    _heading = mo.md(r"""# Generative UI Chatbot""")
    _heading
    return
|
|
|
|
|
@app.cell
def __(mo):
    # Full-width text box holding the CSV location that the loading cell passes
    # to `pl.read_csv`. It is pre-filled with the Hugging Face "Fish" dataset path.
    dataset_input = mo.ui.text(
        full_width=True,
        value="hf://datasets/scikit-learn/Fish/Fish.csv",
    )
    return (dataset_input,)
|
|
|
|
|
@app.cell
def __(dataset_input, mo):
    # Intro text. Interpolating the UI element into the f-string embeds the
    # dataset text box inline in the rendered markdown.
    _intro = f"""

    This chatbot can answer questions about the following dataset: {dataset_input}

    """
    mo.md(_intro)
    return
|
|
|
|
|
@app.cell
def __(dataset_input, mo, pl):
    # Load the CSV named in the dataset text box into `df`.
    # On failure, fall back to an empty DataFrame so cells that read `df` still
    # receive one. The error is shown as a danger callout in this cell's output.
    # Stray whitespace around a pasted path/URL would otherwise make the read fail.
    _source = dataset_input.value.strip()
    try:
        # Only the read belongs in the try. Rendering errors must not be
        # misreported as a failure to load the dataset.
        df = pl.read_csv(_source)
    except Exception as e:  # UI boundary: report any read failure to the user
        df = pl.DataFrame()
        mo.output.replace(
            mo.md(f"""**Error loading dataset**:\n\n{e}""").callout(kind="danger")
        )
    else:
        mo.output.replace(
            mo.md(f"Loaded dataset with {len(df)} rows and {len(df.columns)} columns.")
        )
    return (df,)
|
|
|
|
|
@app.cell

def __():

    # Shared imports. Other cells receive these returned names (mo, os, pl)
    # as their parameters.
    import os



    import marimo as mo

    import polars as pl



    return mo, os, pl
|
|
|
|
|
@app.cell
def __(mo, os):
    # Password-style text box for the OpenAI key. It is pre-filled from the
    # OPENAI_API_KEY environment variable, or empty when that is unset.
    _env_key = os.getenv("OPENAI_API_KEY", "")
    api_key_input = mo.ui.text(
        value=_env_key,
        kind="password",
        label="OpenAI API Key",
    )
    return (api_key_input,)
|
|
|
|
|
@app.cell

def __(api_key_input):

    # Render the API key box. The bare last expression is the cell's output.
    api_key_input

    return
|
|
|
|
|
@app.cell
def __(api_key_input, mo):
    from openai import Client

    # Without a key, mo.stop halts this cell here and shows the message instead
    # of building a client.
    _api_key = api_key_input.value
    mo.stop(not _api_key, mo.md("_Missing API key_"))

    client = Client(api_key=_api_key)
    return Client, client
|
|
|
|
|
@app.cell
def __(df, mo):
    # Tools the LLM may call. ell presumably derives each tool's schema from the
    # function signature and docstring. Parameter names, annotations and
    # docstrings below should therefore be treated as part of the prompt.
    import ell

    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart"""
        import altair as alt

        # Scatter plot (circle marks) of the loaded dataset. The model chooses
        # the x, y and color encodings.
        scatter = alt.Chart(df).mark_circle()
        scatter = scatter.encode(x=x_encoding, y=y_encoding, color=color)
        return scatter.properties(width=500)

    @ell.tool()
    def filter_dataset(sql_query: str):
        """

        Filter a polars dataframe using SQL. Please only use fields from the schema.

        When referring to the table in SQL, call it 'data'.

        """
        # Run the model-written SQL against `df`, registered under the name "data".
        result_frame = df.sql(sql_query, table_name="data")
        # Show the query that produced the table as a fenced SQL label.
        sql_label = f"```sql\n{sql_query}\n```"
        return mo.ui.table(
            result_frame,
            selection=None,
            show_column_summaries=False,
            label=sql_label,
        )

    return chart_data, ell, filter_dataset
|
|
|
|
|
@app.cell
def __(chart_data, client, df, ell, filter_dataset, mo):
    @ell.complex(model="gpt-4o", tools=[chart_data, filter_dataset], client=client)
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        # Per ell's LMP convention, the docstring above is the system prompt and
        # the returned string is the user message. The schema is included so the
        # model can reference real column names in its tool calls.
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def _run_tool_call(tool_call):
        # Execute one tool call requested by the model. Tool failures
        # (e.g. invalid model-written SQL) become an error callout instead of
        # aborting the whole chat reply, mirroring the dataset-loading error style.
        try:
            return tool_call()
        except Exception as e:  # chat boundary: surface the error to the user
            return mo.md(f"**Tool call failed**:\n\n{e}").callout(kind="danger")

    def my_model(messages):
        # Chat callback for mo.ui.chat, which supplies the conversation history.
        # Flatten it into a plain "role: content" transcript. analyze_dataset
        # expects a str prompt, not the repr of a list of message objects.
        transcript = "\n".join(
            f"{message.role}: {message.content}" for message in messages
        )
        response = analyze_dataset(transcript)
        if not response.tool_calls:
            return response.text
        # Run every requested tool rather than only the first. A single result
        # is returned as-is; several are stacked vertically.
        results = [_run_tool_call(tool_call) for tool_call in response.tool_calls]
        return results[0] if len(results) == 1 else mo.vstack(results)

    # The chat widget is the cell's last expression, so it is what gets rendered.
    mo.ui.chat(
        my_model,
        prompts=[
            "Can you chart two columns of your choosing?",
            "Can you find the min, max of all numeric fields?",
            "What is the sum of {{column}}?",
        ],
    )
    return analyze_dataset, my_model
|
|
|
|
|
if __name__ == "__main__":

    # Running this file directly as a script runs the marimo app.
    app.run()
|
|