Spaces:
Running
Running
import streamlit as st | |
from apis import generate_sql, generate_chart | |
from utils import format_sql | |
# NOTE(review): the emoji inside the labels and `icon=` arguments below look
# mis-encoded in this copy of the file (e.g. "π¨"). They are reproduced
# unchanged here; Streamlit's st.toast expects `icon` to be a valid emoji, so
# confirm the intended glyphs against the original source before shipping.


def _show_sql(sql):
    """Render *sql* inside a collapsed expander, formatted by ``format_sql``."""
    with st.expander("π Generated SQL Query", expanded=False):
        st.code(format_sql(sql), language="sql")


def _show_chart(vega_spec):
    """Render *vega_spec* as raw JSON (collapsed) followed by the chart.

    Rendering is best-effort: any exception raised by the Streamlit calls is
    reported through a toast instead of being propagated.
    """
    try:
        with st.expander("π Chart Specification", expanded=False):
            st.json(vega_spec)
        st.vega_lite_chart(vega_spec)
    except Exception as e:  # deliberate catch-all, surfaced to the user
        st.toast(f"Error rendering chart: {e}", icon="π¨")


def _replay_history(messages):
    """Re-draw every stored chat message, including SQL/chart attachments."""
    for entry in messages:
        role = entry["role"]
        with st.chat_message(role):
            st.write(entry["content"])
            if role == "user":
                # User messages never carry attachments.
                continue
            if "sql" in entry:
                _show_sql(entry["sql"])
            if "vega_spec" in entry:
                _show_chart(entry["vega_spec"])


def _post_assistant_message(content, **extra):
    """Store an assistant message in the chat history and render it right away.

    Must be called inside an open ``st.chat_message("assistant")`` container.
    Going through this helper on every path guarantees the message is visible
    in the current run, not only when the history is replayed on the next one.

    Args:
        content: Text shown to the user.
        **extra: Optional attachments stored alongside the text
            (this file uses ``sql`` and ``vega_spec``).

    Returns:
        The message dict that was appended to ``st.session_state.messages``.
    """
    message = {"role": "assistant", "content": content, **extra}
    st.session_state.messages.append(message)
    st.write(content)
    return message


def _answer_prompt(api_key, project_id, prompt, sample_size):
    """Generate SQL, then a chart, for *prompt* and render the results.

    ``generate_sql`` and ``generate_chart`` are each unpacked as a
    ``(response, error)`` pair; a falsy response is treated as failure.
    """
    with st.spinner("Generating SQL query..."):
        sql_response, error = generate_sql(
            api_key, project_id, prompt, st.session_state.thread_id
        )

    if not sql_response:
        st.toast(f"Error generating SQL: {error}", icon="π¨")
        _post_assistant_message(
            "Sorry, I couldn't process your request. Please check your API credentials and try again."
        )
        return

    sql_query = sql_response.get("sql", "")
    # Carry the thread id forward so follow-up calls reuse the same thread.
    st.session_state.thread_id = sql_response.get("threadId", "")

    if not sql_query:
        st.toast("No SQL query was generated. Please try rephrasing your question.", icon="π¨")
        _post_assistant_message(
            "I couldn't generate a SQL query for your question. Please try rephrasing it or make sure it's related to your data."
        )
        return

    st.toast("SQL query generated successfully!", icon="π")
    _post_assistant_message(
        f"I've generated a SQL query for your question: '{prompt}'",
        sql=sql_query,
    )
    _show_sql(sql_query)

    with st.spinner("Generating chart..."):
        chart_response, error = generate_chart(
            api_key,
            project_id,
            prompt,
            sql_query,
            thread_id=st.session_state.thread_id,
            sample_size=sample_size,
        )

    if not chart_response:
        st.toast(f"Failed to generate chart. Please check your query and try again.: {error}", icon="π¨")
        return

    vega_spec = chart_response.get("vegaSpec", {})
    if not vega_spec:
        st.toast("Failed to generate chart. Please check your query and try again.", icon="π¨")
        return

    st.toast("Chart generated successfully!", icon="π")
    _post_assistant_message(
        f"I've generated a Chart for your question: '{prompt}'",
        vega_spec=vega_spec,
    )
    _show_chart(vega_spec)


def main():
    """Render the chart-generation demo page.

    Requires ``api_key`` and ``project_id`` to be present and non-empty in
    ``st.session_state``; otherwise an error is shown and the page stops.
    Keeps the conversation in ``st.session_state.messages`` and the current
    thread id in ``st.session_state.thread_id``.
    """
    st.title("π Wren AI Cloud API Demo - Chart Generation")

    # A missing key and an empty value are handled identically.
    api_key = st.session_state.get("api_key")
    project_id = st.session_state.get("project_id")
    if not api_key or not project_id:
        st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
        return

    st.markdown('Using APIs: [SQL Generation](https://wrenai.readme.io/reference/cloud_post_generate-sql), [Chart Generation](https://wrenai.readme.io/reference/cloud_post_generate-vega-chart)')

    # Sidebar for API configuration
    with st.sidebar:
        st.header("π§ Configuration")
        sample_size = st.slider(
            "Chart Sample Size",
            min_value=100,
            max_value=10000,
            value=1000,
            step=100,
            help="Number of data points to include in charts",
        )

    # Initialize chat history
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "thread_id" not in st.session_state:
        st.session_state.thread_id = ""

    # Display chat history
    _replay_history(st.session_state.messages)

    # Chat input
    if prompt := st.chat_input("Ask a question about your data..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            _answer_prompt(api_key, project_id, prompt, sample_size)

    # Clear chat button
    if st.sidebar.button("ποΈ Clear Chat History"):
        st.session_state.messages = []
        st.session_state.thread_id = ""
        st.rerun()
# Script entry point: build the page only when this file is executed directly
# (as Streamlit does when running it), not when it is imported.
if __name__ == "__main__":
    main()