# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "ell-ai==0.0.14",
#     "marimo",
#     "openai==1.53.0",
#     "polars==1.12.0",
#     "altair==5.4.1",
# ]
# ///

import marimo

__generated_with = "0.9.20"
app = marimo.App(width="medium")


@app.cell
def __(mo):
    # Notebook title.
    mo.md(r"""# Generative UI Chatbot""")
    return


@app.cell
def __(mo):
    # Text box holding the dataset location, pre-filled with the Fish CSV.
    _fish_csv_url = "hf://datasets/scikit-learn/Fish/Fish.csv"
    dataset_input = mo.ui.text(full_width=True, value=_fish_csv_url)
    return (dataset_input,)


@app.cell
def __(dataset_input, mo):
    # Intro text with the dataset input embedded in the markdown.
    mo.md(f""" This chatbot can answer questions about the following dataset: {dataset_input} """)
    return


@app.cell
def __(dataset_input, mo, pl):
    # Read the CSV named in the text box. On any failure fall back to an
    # empty DataFrame and show the error as a "danger" callout.
    try:
        df = pl.read_csv(dataset_input.value)
        _summary = f"Loaded dataset with {len(df)} rows and {len(df.columns)} columns."
        mo.output.replace(mo.md(_summary))
    except Exception as e:
        df = pl.DataFrame()
        _failure = mo.md(f"""**Error loading dataset**:\n\n{e}""")
        mo.output.replace(_failure.callout(kind="danger"))
    return (df,)


@app.cell
def __():
    # Shared imports made available to the other cells.
    import os

    import marimo as mo
    import polars as pl

    return mo, os, pl


@app.cell
def __(mo, os):
    # Password field for the OpenAI key; starts from OPENAI_API_KEY when
    # that environment variable is set, otherwise from an empty string.
    _preset_key = os.environ.get("OPENAI_API_KEY") or ""
    api_key_input = mo.ui.text(
        value=_preset_key,
        kind="password",
        label="OpenAI API Key",
    )
    return (api_key_input,)


@app.cell
def __(api_key_input):
    # Display the key input.
    api_key_input
    return


@app.cell
def __(api_key_input, mo):
    from openai import Client

    # mo.stop receives the "no key entered" condition and the notice to show.
    _key = api_key_input.value
    mo.stop(not _key, mo.md("_Missing API key_"))
    client = Client(api_key=_key)
    return Client, client


@app.cell
def __(df, mo):
    import ell

    # NOTE(review): the docstrings of the two tools below are presumably
    # passed to the model by ell as tool descriptions -- treat them as
    # runtime text and keep their wording stable.

    @ell.tool()
    def chart_data(x_encoding: str, y_encoding: str, color: str):
        """Generate an altair chart"""
        import altair as alt

        # Circle-mark chart over the loaded DataFrame, 500 wide.
        circles = alt.Chart(df).mark_circle()
        encoded = circles.encode(x=x_encoding, y=y_encoding, color=color)
        return encoded.properties(width=500)

    @ell.tool()
    def filter_dataset(sql_query: str):
        """
        Filter a polars dataframe using SQL. Please only use fields from the schema.
        When referring to the table in SQL, call it 'data'.
        """
        # Run the query with the DataFrame exposed under the table name
        # "data", then wrap the result in a table labelled with the SQL.
        return mo.ui.table(
            df.sql(sql_query, table_name="data"),
            label=f"```sql\n{sql_query}\n```",
            selection=None,
            show_column_summaries=False,
        )

    return chart_data, ell, filter_dataset


@app.cell
def __(chart_data, client, df, ell, filter_dataset, mo):
    _tools = [chart_data, filter_dataset]

    # NOTE(review): the docstring is presumably used by ell as the system
    # prompt -- keep its wording stable.
    @ell.complex(model="gpt-4o", tools=_tools, client=client)
    def analyze_dataset(prompt: str) -> str:
        """You are a data scientist that can analyze a dataset"""
        return f"I have a dataset with schema: {df.schema}. \n{prompt}"

    def my_model(messages):
        # Chat callback: plain text when the model requested no tool,
        # otherwise the result of invoking the first requested tool call.
        reply = analyze_dataset(messages)
        if not reply.tool_calls:
            return reply.text
        first_call = reply.tool_calls[0]
        return first_call()

    _starter_prompts = [
        "Can you chart two columns of your choosing?",
        "Can you find the min, max of all numeric fields?",
        "What is the sum of {{column}}?",
    ]
    mo.ui.chat(my_model, prompts=_starter_prompts)
    return analyze_dataset, my_model


if __name__ == "__main__":
    app.run()