sqlbot / db /memory.py
jashdoshi77's picture
added endpoints
65fba82
"""Conversation memory stored in PostgreSQL (Neon).
Keeps the last N turns per conversation so the AI can
use recent context for follow‑up questions.
"""
from __future__ import annotations
import json
from typing import Any, List, Dict
from sqlalchemy import text
from db.connection import get_engine
_TABLE_CREATED = False
def _ensure_table() -> None:
"""Create the chat_history table if it doesn't exist, and add query_result column if missing."""
global _TABLE_CREATED
if _TABLE_CREATED:
return
engine = get_engine()
with engine.begin() as conn:
conn.execute(text(
"""
CREATE TABLE IF NOT EXISTS chat_history (
id BIGSERIAL PRIMARY KEY,
conversation_id TEXT NOT NULL,
question TEXT NOT NULL,
answer TEXT NOT NULL,
sql_query TEXT,
query_result TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
"""
))
# Migrate existing tables that don't have the query_result column yet
conn.execute(text(
"""
ALTER TABLE chat_history ADD COLUMN IF NOT EXISTS query_result TEXT;
"""
))
_TABLE_CREATED = True
def add_turn(
conversation_id: str,
question: str,
answer: str,
sql_query: str | None,
query_result: list | None = None,
) -> None:
"""Append a single Q/A turn to the history."""
_ensure_table()
engine = get_engine()
result_json = json.dumps(query_result, default=str) if query_result else None
insert_stmt = text(
"""
INSERT INTO chat_history (conversation_id, question, answer, sql_query, query_result)
VALUES (:conversation_id, :question, :answer, :sql_query, :query_result)
"""
)
with engine.begin() as conn:
conn.execute(
insert_stmt,
{
"conversation_id": conversation_id,
"question": question,
"answer": answer,
"sql_query": sql_query,
"query_result": result_json,
},
)
def delete_turn(turn_id: int) -> None:
"""Delete a single chat history turn by its id."""
_ensure_table()
engine = get_engine()
with engine.begin() as conn:
conn.execute(
text("DELETE FROM chat_history WHERE id = :id"),
{"id": turn_id},
)
def get_full_history(conversation_id: str) -> List[Dict[str, Any]]:
"""Return ALL turns for a conversation (oldest first) for the sidebar display."""
_ensure_table()
engine = get_engine()
query = text(
"""
SELECT id, question, answer, sql_query, query_result, created_at
FROM chat_history
WHERE conversation_id = :conversation_id
ORDER BY created_at ASC
"""
)
with engine.connect() as conn:
rows = conn.execute(
query, {"conversation_id": conversation_id}
).mappings().all()
result = []
for r in rows:
row = dict(r)
# Deserialize query_result JSON string back to a list
if row.get("query_result"):
try:
row["query_result"] = json.loads(row["query_result"])
except (json.JSONDecodeError, TypeError):
row["query_result"] = None
result.append(row)
return result
def get_turn_by_id(turn_id: int) -> Dict[str, Any] | None:
"""Return a single history turn by its primary key id."""
_ensure_table()
engine = get_engine()
query = text(
"""
SELECT id, question, answer, sql_query, query_result, created_at
FROM chat_history
WHERE id = :id
"""
)
with engine.connect() as conn:
row = conn.execute(query, {"id": turn_id}).mappings().first()
if row is None:
return None
result = dict(row)
if result.get("query_result"):
try:
result["query_result"] = json.loads(result["query_result"])
except (json.JSONDecodeError, TypeError):
result["query_result"] = None
return result
def get_recent_history(conversation_id: str, limit: int = 5) -> List[Dict[str, Any]]:
"""Return the most recent `limit` turns for a conversation (oldest first)."""
_ensure_table()
engine = get_engine()
query = text(
"""
SELECT question, answer, sql_query, created_at
FROM chat_history
WHERE conversation_id = :conversation_id
ORDER BY created_at DESC
LIMIT :limit
"""
)
with engine.connect() as conn:
rows = conn.execute(
query, {"conversation_id": conversation_id, "limit": limit}
).mappings().all()
# Reverse so caller sees oldest → newest
return list(reversed([dict(r) for r in rows]))