# Source: commit e0589c8 by mylessss ("add altair")
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "ell-ai==0.0.14",
# "marimo",
# "openai==1.53.0",
# "polars==1.12.0",
# "altair==5.4.1",
# ]
# ///
import marimo
# marimo version this notebook file was generated with.
__generated_with = "0.9.20"
# The notebook application; every @app.cell-decorated function below registers a cell.
app = marimo.App(width="medium")
@app.cell
def __(mo):
    # Page heading shown at the top of the notebook.
    _heading = "# Generative UI Chatbot"
    mo.md(_heading)
    return
@app.cell
def __(mo):
    # Editable location of the CSV to analyze, pre-filled with a sample dataset.
    dataset_input = mo.ui.text(
        full_width=True,
        value="hf://datasets/scikit-learn/Fish/Fish.csv",
    )
    return (dataset_input,)
@app.cell
def __(dataset_input, mo):
    # Intro text with the dataset text box embedded inline.
    _blurb = f"""
    This chatbot can answer questions about the following dataset: {dataset_input}
    """
    mo.md(_blurb)
    return
@app.cell
def __(dataset_input, mo, pl):
    # Load the CSV named in the text box. On failure, fall back to an empty
    # frame (so `df` is always defined for the cells that use it) and show
    # the error as a danger callout instead.
    try:
        df = pl.read_csv(dataset_input.value)
    except Exception as e:
        df = pl.DataFrame()
        mo.output.replace(
            mo.md(f"""**Error loading dataset**:\n\n{e}""").callout(kind="danger")
        )
    else:
        # Report success outside the `try` so that a problem rendering this
        # message is never mistaken for a load failure (which would replace
        # the successfully loaded frame with an empty one).
        mo.output.replace(
            mo.md(f"Loaded dataset with {len(df)} rows and {len(df.columns)} columns.")
        )
    return (df,)
@app.cell
def __():
    # Standard library
    import os

    # Third-party
    import marimo as mo
    import polars as pl

    return mo, os, pl
@app.cell
def __(mo, os):
    # Pre-fill from the environment when OPENAI_API_KEY is set; otherwise empty.
    _env_key = os.environ.get("OPENAI_API_KEY", "")
    api_key_input = mo.ui.text(
        kind="password",
        label="OpenAI API Key",
        value=_env_key,
    )
    return (api_key_input,)
@app.cell
def __(api_key_input):
    # Render the API-key field; its value is read when constructing the OpenAI client.
    api_key_input
    return
@app.cell
def __(api_key_input, mo):
    from openai import Client

    # Stop here, showing a notice, until a key has been entered.
    _key = api_key_input.value
    mo.stop(not _key, mo.md("_Missing API key_"))
    client = Client(api_key=_key)
    return Client, client
@app.cell
def __(df, mo):
    import ell

    # NOTE: tool docstrings and parameter names are what ell exposes to the
    # model as the tool description/schema, so they are part of behavior.
    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart"""
        import altair as alt

        # Scatter plot of the loaded dataframe using the model-chosen columns.
        points = alt.Chart(df).mark_circle()
        encoded = points.encode(x=x_encoding, y=y_encoding, color=color)
        return encoded.properties(width=500)

    @ell.tool()
    def filter_dataset(sql_query: str):
        """
        Filter a polars dataframe using SQL. Please only use fields from the schema.
        When referring to the table in SQL, call it 'data'.
        """
        result = df.sql(sql_query, table_name="data")
        # Caption the table with the query so the user sees exactly what ran.
        caption = f"```sql\n{sql_query}\n```"
        return mo.ui.table(
            result,
            label=caption,
            selection=None,
            show_column_summaries=False,
        )

    return chart_data, ell, filter_dataset
@app.cell
def __(chart_data, client, df, ell, filter_dataset, mo):
    @ell.complex(model="gpt-4o", tools=[chart_data, filter_dataset], client=client)
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def _render_transcript(messages) -> str:
        """Flatten the chat history into plain ``role: text`` lines.

        ``analyze_dataset`` expects a string prompt; passing the message list
        directly would embed Python reprs of the message objects instead.
        """
        lines = []
        for message in messages:
            content = message.content
            # Assistant turns may hold rendered UI (tables/charts) returned by
            # tools; their repr is noise to the model, so use a placeholder.
            if not isinstance(content, str):
                content = "[rendered output]"
            lines.append(f"{message.role}: {content}")
        return "\n".join(lines)

    def my_model(messages):
        """Chat handler: ask the model, then run its first tool call if any.

        Returns the tool's output (a chart or table) when the model requested
        a tool, an error callout if that tool raised, else the reply text.
        """
        response = analyze_dataset(_render_transcript(messages))
        if response.tool_calls:
            try:
                return response.tool_calls[0]()
            except Exception as e:
                # e.g. the model wrote invalid SQL; surface it in the chat
                # instead of letting the handler crash.
                return mo.md(f"**Error running tool**:\n\n{e}").callout(kind="danger")
        return response.text

    mo.ui.chat(
        my_model,
        prompts=[
            "Can you chart two columns of your choosing?",
            "Can you find the min, max of all numeric fields?",
            "What is the sum of {{column}}?",
        ],
    )
    return analyze_dataset, my_model
# Allow executing this notebook file directly with the Python interpreter.
if __name__ == "__main__":
    app.run()