# wrenai-cloud-api-demo: src/pages/3_Chart_Generation.py
# Commit 99a7fbb by cyyeh: "update"
import streamlit as st
from apis import generate_sql, generate_chart
from utils import format_sql
def _render_sql(sql_query):
    """Show *sql_query*, pretty-printed with ``format_sql``, in a collapsed expander."""
    with st.expander("📝 Generated SQL Query", expanded=False):
        st.code(format_sql(sql_query), language="sql")


def _render_chart(vega_spec):
    """Show the raw Vega-Lite spec in a collapsed expander, then draw the chart.

    Rendering problems are reported with a toast rather than raised, so a
    single bad spec does not abort the rest of the page.
    """
    try:
        with st.expander("📊 Chart Specification", expanded=False):
            st.json(vega_spec)
        st.vega_lite_chart(vega_spec)
    except Exception as e:  # deliberate best-effort at the UI boundary
        st.toast(f"Error rendering chart: {e}", icon="🚨")


def _render_history(messages):
    """Replay stored chat *messages*, including any SQL or chart attached to replies."""
    for message in messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
            # Only assistant replies carry "sql" / "vega_spec" payloads.
            if message["role"] != "user":
                if "sql" in message:
                    _render_sql(message["sql"])
                if "vega_spec" in message:
                    _render_chart(message["vega_spec"])


def _add_assistant_message(content, **payload):
    """Store an assistant reply in the chat history and display its text now.

    Extra keyword arguments (``sql=...``, ``vega_spec=...``) are stored on the
    message so ``_render_history`` can re-draw them on later reruns.
    """
    message = {"role": "assistant", "content": content, **payload}
    st.session_state.messages.append(message)
    st.write(content)
    return message


def _handle_prompt(api_key, project_id, prompt, sample_size):
    """Answer *prompt*: generate SQL, then a chart from it, showing and storing replies.

    Updates ``st.session_state.messages`` and ``st.session_state.thread_id``.
    """
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Generating SQL query..."):
            sql_response, error = generate_sql(api_key, project_id, prompt, st.session_state.thread_id)

        if not sql_response:
            st.toast(f"Error generating SQL: {error}", icon="🚨")
            _add_assistant_message(
                "Sorry, I couldn't process your request. Please check your API credentials and try again."
            )
            return

        # Keep the current thread when the response carries no thread id, so a
        # follow-up question is not silently detached from the conversation.
        st.session_state.thread_id = sql_response.get("threadId") or st.session_state.thread_id

        sql_query = sql_response.get("sql", "")
        if not sql_query:
            st.toast("No SQL query was generated. Please try rephrasing your question.", icon="🚨")
            _add_assistant_message(
                "I couldn't generate a SQL query for your question. Please try rephrasing it or make sure it's related to your data."
            )
            return

        st.toast("SQL query generated successfully!", icon="🎉")
        _add_assistant_message(f"I've generated a SQL query for your question: '{prompt}'", sql=sql_query)
        _render_sql(sql_query)

        with st.spinner("Generating chart..."):
            chart_response, error = generate_chart(
                api_key,
                project_id,
                prompt,
                sql_query,
                thread_id=st.session_state.thread_id,
                sample_size=sample_size,
            )

        if not chart_response:
            st.toast(
                f"Failed to generate chart. Please check your query and try again. Error: {error}",
                icon="🚨",
            )
            return

        vega_spec = chart_response.get("vegaSpec", {})
        if not vega_spec:
            st.toast("Failed to generate chart. Please check your query and try again.", icon="🚨")
            return

        st.toast("Chart generated successfully!", icon="🎉")
        _add_assistant_message(f"I've generated a Chart for your question: '{prompt}'", vega_spec=vega_spec)
        _render_chart(vega_spec)


def main():
    """Render the Chart Generation demo page.

    Expects ``api_key`` and ``project_id`` in ``st.session_state`` (entered on
    the Home page). Each question goes to ``generate_sql``. The returned SQL is
    then passed to ``generate_chart``, and the resulting Vega-Lite spec is
    drawn. Chat history (``messages``) and the Wren AI thread id
    (``thread_id``) live in ``st.session_state`` so they survive reruns.
    """
    st.title("📊 Wren AI Cloud API Demo - Chart Generation")

    # One check covers both "never set" and "set but empty".
    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
        return

    st.markdown('Using APIs: [SQL Generation](https://wrenai.readme.io/reference/cloud_post_generate-sql), [Chart Generation](https://wrenai.readme.io/reference/cloud_post_generate-vega-chart)')

    # Initialize chat state on the first run of this session.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thread_id" not in st.session_state:
        st.session_state.thread_id = ""

    # Sidebar: chart configuration and chat reset.
    with st.sidebar:
        st.header("🔧 Configuration")
        sample_size = st.slider(
            "Chart Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in charts"
        )
        # Handle the reset before history is replayed, so stale messages are
        # not drawn just before the rerun clears them.
        if st.button("🗑️ Clear Chat History"):
            st.session_state.messages = []
            st.session_state.thread_id = ""
            st.rerun()

    _render_history(st.session_state.messages)

    if prompt := st.chat_input("Ask a question about your data..."):
        _handle_prompt(api_key, project_id, prompt, sample_size)
# Run the page when executed as a script (``streamlit run`` executes it as __main__).
if __name__ == "__main__":
    main()