# src/ui.py — Cashy Gradio chat UI
# (deployed to Hugging Face Spaces via GitHub Actions; commit 17a78b5)
import json
import uuid
import time
import logging
import gradio as gr
from langchain_core.messages import HumanMessage
from langgraph.types import Command
from src.config import settings
from src.agent.nodes import reset_model
from src.db.connection import get_connection
logger = logging.getLogger("cashy.ui")
def list_threads():
    """Get all thread_ids from the checkpoints table, most recent first."""
    try:
        # One row per thread, ranked by its newest checkpoint.
        with get_connection() as conn, conn.cursor() as cur:
            cur.execute("""
                SELECT thread_id, MAX(checkpoint_id) AS latest
                FROM checkpoints
                GROUP BY thread_id
                ORDER BY latest DESC
            """)
            rows = cur.fetchall()
        return [thread_id for (thread_id, _latest) in rows]
    except Exception as exc:
        # Best-effort: an unreachable/missing table just yields an empty list.
        logger.warning("Could not list threads: %s", exc)
        return []
def load_thread_history(agent, thread_id):
    """Load messages from a thread and convert to Gradio chatbot format.

    Human messages become "user" entries, non-empty AI messages become
    "assistant" entries, and tool messages that carry a chart payload are
    surfaced as assistant image entries ({"path": ...}).
    """
    state = agent.get_state({"configurable": {"thread_id": thread_id}})
    chat = []
    for msg in state.values.get("messages", []):
        kind = msg.type
        if kind == "human":
            chat.append({"role": "user", "content": msg.content})
        elif kind == "ai":
            if msg.content:
                chat.append({"role": "assistant", "content": msg.content})
        elif kind == "tool":
            # Tool output is only shown when it is a JSON dict with a chart.
            try:
                payload = json.loads(msg.content)
            except (json.JSONDecodeError, TypeError):
                continue
            if isinstance(payload, dict) and "chart_path" in payload:
                chat.append({"role": "assistant", "content": {"path": payload["chart_path"]}})
    return chat
def get_thread_title(agent, thread_id):
    """Extract first user message as thread title (truncated to 50 chars).

    Falls back to a 12-char thread_id prefix when the state can't be read
    or contains no human messages.
    """
    try:
        state = agent.get_state({"configurable": {"thread_id": thread_id}})
        for msg in state.values.get("messages", []):
            if msg.type != "human":
                continue
            text = msg.content
            # Ellipsis only when content was actually truncated.
            return text[:50] + "..." if len(text) > 50 else text[:50]
    except Exception:
        pass
    return thread_id[:12]
def get_thread_choices(agent, min_user_messages=2):
    """Build dropdown choices as (title, thread_id) tuples.

    Only includes threads with at least min_user_messages user messages,
    filtering out orphan single-exchange threads. Threads whose state
    can't be read are skipped silently.
    """
    options = []
    for tid in list_threads():
        try:
            state = agent.get_state({"configurable": {"thread_id": tid}})
            human_msgs = [m for m in state.values.get("messages", []) if m.type == "human"]
            if len(human_msgs) < min_user_messages:
                continue
            text = human_msgs[0].content
            label = text[:50] + "..." if len(text) > 50 else text[:50]
            options.append((label, tid))
        except Exception:
            continue
    return options
def _format_table(data, header, divider):
    """Render a dict as Markdown table lines: header, divider, one row per
    key/value pair, and a trailing blank line.

    Keys are shown in Title Case with underscores spaced out; the "amount"
    key is formatted as a currency value ($1,234.50).
    """
    rows = [header, divider]
    for key, value in data.items():
        display_key = key.replace("_", " ").title()
        if key == "amount":
            display_value = f"${value:,.2f}"
        else:
            display_value = str(value)
        rows.append(f"| {display_key} | {display_value} |")
    rows.append("")  # blank line so Markdown ends the table cleanly
    return rows


def format_confirmation(interrupt_data):
    """Format an interrupt payload as a user-friendly confirmation message.

    Renders the action label plus up to three Markdown tables -- "details",
    "changes" (new values for updates), and "current" (pre-update values) --
    followed by a yes/no prompt. The payload's "message" field is not shown.

    Args:
        interrupt_data: dict with optional keys "action", "details",
            "changes", and "current".

    Returns:
        Markdown string ready to append to the chat history.
    """
    action = interrupt_data.get("action", "unknown")
    # Build a readable action label; unknown actions fall back to Title Case.
    action_labels = {
        "create_transaction": "Create Transaction",
        "update_transaction": "Update Transaction",
        "delete_transaction": "Delete Transaction",
    }
    label = action_labels.get(action, action.replace("_", " ").title())
    lines = [f"**Confirm: {label}**\n"]
    # Show details as a table
    details = interrupt_data.get("details", {})
    if details:
        lines += _format_table(details, "| Field | Value |", "|-------|-------|")
    # Show changes for update operations
    changes = interrupt_data.get("changes", {})
    if changes:
        lines.append("**Changes:**\n")
        lines += _format_table(changes, "| Field | New Value |", "|-------|-----------|")
    # Show current values for update operations
    current = interrupt_data.get("current", {})
    if current:
        lines.append("**Current values:**\n")
        lines += _format_table(current, "| Field | Value |", "|-------|-------|")
    lines.append("Reply **yes** to confirm or **no** to cancel.")
    return "\n".join(lines)
# Welcome banner shown as the first assistant message in demo mode
# (pre-seeded demo database).
WELCOME_MESSAGE_DEMO = """\
Hi! I'm **Cashy**, your AI financial advisor.
I'm connected to a demo database with **4 months of financial data** for a US-based freelance web developer:
- **11 accounts** — Chase, PayPal, Stripe, Wise, Marcus, Fidelity, credit cards, and cash
- **233 transactions** — client invoices, business expenses, personal spending, transfers
- **20 budgets** — monthly spending limits across 35 categories
**Ready to go** with the free tier, or switch to your own LLM provider in the sidebar.
Ask me anything about your finances. Here are some ideas:
1. **"What accounts do I have?"** — See all accounts and balances
2. **"How much did I spend this month?"** — Spending breakdown by category
3. **"How much did I earn from clients in January?"** — Income tracking
4. **"Am I over budget on anything?"** — Budget vs. actual comparison
5. **"Show me my last 10 transactions"** — Recent transaction history
"""

# Welcome banner shown in personal mode (user's own database).
WELCOME_MESSAGE_PERSONAL = """\
Hi! I'm **Cashy**, your AI financial advisor.
I'm connected to your personal financial database. Ask me anything about your accounts, transactions, spending, or budgets.
Here are some things I can help with:
1. **"What accounts do I have?"** — See all accounts and balances
2. **"How much did I spend this month?"** — Spending breakdown by category
3. **"Show me my last 10 transactions"** — Recent transaction history
4. **"Am I over budget on anything?"** — Budget vs. actual comparison
5. **"Show me a chart of my spending"** — Visual spending analysis
"""

# LLM providers selectable in the sidebar dropdown.
PROVIDERS = ["free-tier", "openai", "anthropic", "google", "huggingface"]

# Default model used for each provider when the user doesn't enter one.
DEFAULT_MODELS = {
    "free-tier": "Qwen/Qwen2.5-7B-Instruct",
    "openai": "gpt-5-mini",
    "anthropic": "claude-sonnet-4-20250514",
    "google": "gemini-2.5-flash",
    "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
}

# Inference backends offered when the "huggingface" provider is selected.
HF_INFERENCE_PROVIDERS = [
    "cerebras",
    "cohere",
    "featherless-ai",
    "fireworks-ai",
    "groq",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "nscale",
    "ovhcloud",
    "sambanova",
    "scaleway",
    "together",
]
def create_ui(agent):
    """Create the Gradio chat UI with compact reference sidebar.

    Builds the Blocks layout (chat column plus reference/settings sidebar),
    wires all event handlers as closures over *agent* and the module-level
    ``settings``, and returns the app together with its theme.

    Args:
        agent: LangGraph agent exposing ``invoke`` and ``get_state``.

    Returns:
        Tuple ``(demo, theme)`` -- the gr.Blocks app and a gr theme object.
    """
    # Snapshot provider/mode at build time; the handlers below re-read
    # `settings` so runtime provider switches take effect without a rebuild.
    current_provider = settings.resolved_provider or "openai"
    has_provider = settings.resolved_provider is not None
    is_free = current_provider == "free-tier"
    is_demo = settings.app_mode == "demo"
    mode_label = "Demo" if is_demo else "Personal"
    welcome_text = WELCOME_MESSAGE_DEMO if is_demo else WELCOME_MESSAGE_PERSONAL
    # Appended to every free-tier response to nudge users toward BYOK providers.
    FREE_TIER_DISCLAIMER = (
        "\n\n---\n*Free tier uses a lightweight open-source model. "
        "For better results, switch to OpenAI, Anthropic, or Google in the sidebar.*"
    )

    def respond(message, history, thread_id, pending):
        """Handle one chat turn.

        Args:
            message: user's textbox input.
            history: current chatbot message list (mutated by appending).
            thread_id: LangGraph checkpoint thread for this session.
            pending: True if the previous turn ended in a write-confirmation
                interrupt, in which case *message* is treated as yes/no.

        Returns:
            Tuple of (cleared textbox, updated history, thread_id,
            new pending-interrupt flag).
        """
        config = {"configurable": {"thread_id": thread_id}}
        logger.info(">>> User [thread=%s]: %s", thread_id[:8], message)
        start = time.time()
        try:
            # --- Resume from interrupt (user confirming/rejecting) ---
            if pending:
                # Any reply outside this allowlist counts as a rejection.
                approved = message.strip().lower() in ("yes", "approve", "confirm", "y")
                logger.info("Interrupt response: %s", "approved" if approved else "rejected")
                result = agent.invoke(Command(resume={"approved": approved}), config)
                response = result["messages"][-1].content
                elapsed = time.time() - start
                logger.info("<<< Response [%.1fs]: %s", elapsed, response[:120])
                if settings.resolved_provider == "free-tier":
                    response += FREE_TIER_DISCLAIMER
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": response})
                return "", history, thread_id, False
            # --- Normal flow ---
            # Count existing messages so we only scan new ones for charts
            state = agent.get_state(config)
            prev_count = len(state.values.get("messages", []))
            result = agent.invoke(
                {"messages": [HumanMessage(content=message)]},
                config,
            )
            # --- Check for interrupt (write operation needs confirmation) ---
            if "__interrupt__" in result:
                interrupt_data = result["__interrupt__"][0].value
                confirmation_msg = format_confirmation(interrupt_data)
                elapsed = time.time() - start
                logger.info("<<< Interrupt [%.1fs]: %s", elapsed, interrupt_data.get("action", "unknown"))
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": confirmation_msg})
                # pending=True so the next message resumes the interrupt.
                return "", history, thread_id, True
            # --- Normal response (no interrupt) ---
            response = result["messages"][-1].content
            elapsed = time.time() - start
            logger.info("<<< Response [%.1fs]: %s", elapsed, response[:120])
            # Scan only NEW messages for chart images (skip prior history)
            chart_paths = []
            for msg in result["messages"][prev_count:]:
                if hasattr(msg, "type") and msg.type == "tool":
                    try:
                        data = json.loads(msg.content)
                        if isinstance(data, dict) and "chart_path" in data:
                            chart_paths.append(data["chart_path"])
                    except (json.JSONDecodeError, TypeError):
                        # Non-JSON tool output carries no chart; ignore.
                        pass
            if settings.resolved_provider == "free-tier":
                response += FREE_TIER_DISCLAIMER
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": response})
            for path in chart_paths:
                history.append({"role": "assistant", "content": {"path": path}})
            return "", history, thread_id, False
        except Exception as e:
            logger.error("<<< Error: %s", e)
            error_str = str(e).lower()
            # NOTE(review): `and` binds tighter than `or`, so this matches any
            # "ssl" error OR ("connection" AND "closed") -- confirm intended.
            if "ssl" in error_str or "connection" in error_str and "closed" in error_str:
                msg = (
                    "The database connection was lost (the cloud database likely went to sleep). "
                    "Please try again in a few seconds -- it should reconnect automatically. "
                    "If the issue persists, restart the Space from Settings."
                )
            else:
                msg = f"**Error:** {e}"
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": msg})
            return "", history, thread_id, False

    def switch_provider(provider):
        """Apply a provider change from the dropdown.

        Resets the cached model, clears any custom model name, and returns
        updates for the status text and the BYOK input visibility.
        """
        settings.llm_provider = provider
        settings.model_name = ""  # reset to default for new provider
        reset_model()
        model = DEFAULT_MODELS.get(provider, "default")
        is_free = provider == "free-tier"
        is_hf = provider == "huggingface"
        show_byok = not is_free  # free-tier hides API key, model, HF provider
        logger.info("Provider switched to: %s (%s)", provider, model)
        if is_free:
            status = f"Using **Free Tier** ({model}) -- no API key needed"
        else:
            status = f"Switched to **{provider.capitalize()}** ({model})"
        # Order matches the outputs list wired below: status, api key box,
        # model name box, HF provider dropdown, save button.
        return (
            status,
            gr.update(visible=show_byok, value=""),
            gr.update(visible=show_byok, placeholder=f"Default: {model}", value=""),
            gr.update(visible=is_hf),
            gr.update(visible=show_byok),
        )

    def set_api_key(provider, api_key, model_name, hf_provider):
        """Store the user's API key (and optional model/HF provider) in settings."""
        key = api_key.strip()
        if not key:
            return "No key entered."
        # Map provider -> settings attribute holding its credential.
        key_fields = {
            "openai": "openai_api_key",
            "anthropic": "anthropic_api_key",
            "google": "google_api_key",
            "huggingface": "hf_token",
        }
        field = key_fields.get(provider)
        if not field:
            return f"Unknown provider: {provider}"
        setattr(settings, field, key)
        settings.llm_provider = provider
        if model_name.strip():
            settings.model_name = model_name.strip()
        if provider == "huggingface" and hf_provider:
            settings.hf_inference_provider = hf_provider
        reset_model()
        model = settings.model_name or DEFAULT_MODELS.get(provider, "default")
        logger.info("API key set for provider: %s (%s)", provider, model)
        return f"API key saved. Using **{provider.capitalize()}** ({model})."

    def set_model(provider, model_name):
        """Override the model name; empty input reverts to the provider default."""
        name = model_name.strip()
        settings.model_name = name
        reset_model()
        model = name or DEFAULT_MODELS.get(provider, "default")
        logger.info("Model changed to: %s", model)
        return f"Model set to **{model}**."

    def set_hf_provider(hf_provider):
        """Switch the Hugging Face inference backend."""
        settings.hf_inference_provider = hf_provider
        reset_model()
        logger.info("HF inference provider changed to: %s", hf_provider)
        return f"Inference provider set to **{hf_provider}**."

    # Initial chatbot contents: a single assistant welcome message.
    welcome = [{"role": "assistant", "content": welcome_text}]
    if is_demo:
        theme = gr.themes.Glass(primary_hue="indigo")
    else:
        theme = gr.themes.Default()
    # NOTE(review): the theme is returned to the caller rather than passed as
    # gr.Blocks(theme=...) here -- confirm the caller applies it at launch.
    with gr.Blocks(title="Cashy - AI Financial Advisor") as demo:
        gr.Markdown("# Cashy — AI Financial Advisor")
        # Per-session state: checkpoint thread id (fresh UUID per visitor)
        # and whether a write-confirmation interrupt is awaiting yes/no.
        session_thread_id = gr.State(value=lambda: str(uuid.uuid4()))
        pending_interrupt = gr.State(value=False)
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    value=welcome,
                    height=600,
                    buttons=["copy"],
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask about your finances...",
                        show_label=False,
                        scale=9,
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Column(scale=1, min_width=250):
                new_chat_btn = gr.Button("+ New Chat", variant="secondary")
                # Chat history browsing only exists in personal mode.
                if not is_demo:
                    with gr.Accordion("Chat History", open=False):
                        thread_dropdown = gr.Dropdown(
                            choices=[],
                            label="Previous chats",
                            interactive=True,
                        )
                        load_btn = gr.Button("Load Chat")
                gr.Markdown(f"**Mode:** {mode_label}")
                gr.Markdown("---")
                if is_demo:
                    gr.Markdown(
                        "**Demo Data** · Oct 2025 – Jan 2026 · USD\n\n"
                        "11 accounts · 233 transactions · 20 budgets"
                    )
                    gr.Markdown("---")
                gr.Markdown(
                    "**Capabilities**\n\n"
                    "- Check account balances\n"
                    "- Analyze spending by category\n"
                    "- Search transaction history\n"
                    "- Compare budgets vs. actual\n"
                    "- Create, update, delete transactions\n"
                    "- Run custom SQL queries"
                )
                gr.Markdown("---")
                if is_demo:
                    gr.Markdown(
                        "**Try asking**\n\n"
                        '*"What\'s the balance on Chase Business?"*\n\n'
                        '*"Am I over budget on anything?"*\n\n'
                        '*"I need a $1,500 laptop -- can I afford it?"*\n\n'
                        '*"Show me a pie chart of my spending"*\n\n'
                        '*"Chart my budget vs actual for January"*'
                    )
                else:
                    gr.Markdown(
                        "**Try asking**\n\n"
                        '*"What accounts do I have?"*\n\n'
                        '*"How much did I spend this month?"*\n\n'
                        '*"Show me a chart of my spending"*'
                    )
                gr.Markdown("---")
                provider_dropdown = gr.Dropdown(
                    choices=PROVIDERS,
                    value=current_provider,
                    label="LLM Provider",
                )
                with gr.Row():
                    api_key_input = gr.Textbox(
                        label="API Key",
                        placeholder="Paste your API key here...",
                        type="password",
                        scale=4,
                        visible=not is_free,
                    )
                    save_key_btn = gr.Button("Save", variant="primary", scale=1, visible=not is_free)
                model_name_input = gr.Textbox(
                    label="Model Name (optional)",
                    placeholder=f"Default: {DEFAULT_MODELS.get(current_provider, '')}",
                    value="",
                    visible=not is_free,
                )
                hf_provider_dropdown = gr.Dropdown(
                    choices=HF_INFERENCE_PROVIDERS,
                    value=settings.hf_inference_provider,
                    label="Inference Provider",
                    visible=current_provider == "huggingface",
                )
                # Initial status line reflects the provider resolved at startup.
                if is_free:
                    status_text = f"Using **Free Tier** ({DEFAULT_MODELS['free-tier']}) -- no API key needed"
                elif has_provider:
                    status_text = f"Using **{current_provider.capitalize()}** ({DEFAULT_MODELS.get(current_provider, 'default')})"
                else:
                    status_text = "No API key configured -- select a provider and enter one above"
                provider_status = gr.Markdown(status_text)
        # --- Event handlers ---
        # Chat events (now include pending_interrupt state)
        chat_inputs = [msg, chatbot, session_thread_id, pending_interrupt]
        chat_outputs = [msg, chatbot, session_thread_id, pending_interrupt]
        if is_demo:
            def new_chat_demo():
                """Start a fresh thread (demo mode: no history dropdown to refresh)."""
                new_id = str(uuid.uuid4())
                logger.info("New chat started [thread=%s]", new_id[:8])
                # NOTE(review): returns the shared `welcome` list object;
                # assumes Gradio copies component values -- confirm.
                return new_id, welcome, "", False
            msg.submit(respond, chat_inputs, chat_outputs)
            submit_btn.click(respond, chat_inputs, chat_outputs)
            new_chat_btn.click(
                new_chat_demo, [], [session_thread_id, chatbot, msg, pending_interrupt]
            )
        else:
            def new_chat():
                """Start a fresh thread and refresh the saved-chats dropdown."""
                new_id = str(uuid.uuid4())
                choices = get_thread_choices(agent)
                logger.info("New chat started [thread=%s]", new_id[:8])
                return new_id, welcome, "", gr.update(choices=choices), False
            def load_thread(selected_thread_id):
                """Switch the session to a previously saved thread."""
                if not selected_thread_id:
                    # Nothing selected: leave thread id and chat untouched.
                    return gr.update(), gr.update(), False
                history = load_thread_history(agent, selected_thread_id)
                logger.info("Loaded thread %s (%d messages)", selected_thread_id[:8], len(history))
                return selected_thread_id, history, False
            def refresh_threads():
                """Rebuild the dropdown choices from the checkpoint store."""
                choices = get_thread_choices(agent)
                return gr.update(choices=choices)
            # After every exchange, refresh the saved-chat list.
            msg.submit(respond, chat_inputs, chat_outputs).then(
                refresh_threads, [], [thread_dropdown]
            )
            submit_btn.click(respond, chat_inputs, chat_outputs).then(
                refresh_threads, [], [thread_dropdown]
            )
            new_chat_btn.click(
                new_chat, [], [session_thread_id, chatbot, msg, thread_dropdown, pending_interrupt]
            )
            load_btn.click(
                load_thread, [thread_dropdown], [session_thread_id, chatbot, pending_interrupt]
            )
        provider_dropdown.change(
            switch_provider, [provider_dropdown],
            [provider_status, api_key_input, model_name_input, hf_provider_dropdown, save_key_btn],
        )
        api_key_inputs = [provider_dropdown, api_key_input, model_name_input, hf_provider_dropdown]
        api_key_input.submit(set_api_key, api_key_inputs, [provider_status])
        save_key_btn.click(set_api_key, api_key_inputs, [provider_status])
        model_name_input.submit(set_model, [provider_dropdown, model_name_input], [provider_status])
        hf_provider_dropdown.change(set_hf_provider, [hf_provider_dropdown], [provider_status])
        # Populate thread list on page load (personal mode only)
        if not is_demo:
            demo.load(refresh_threads, [], [thread_dropdown])
    return demo, theme