# ICIJ_SQL_agent / app.py
# Didier Guillevic
# Minor fixes.
# ede8a06
"""app.py
Smolagents agent given an SQL tool over a SQLite database built with data files
from the International Consortium of Investigative Journalists (ICIJ.org).
Agentic framework:
- smolagents
Database:
- SQLite
Generation:
- Mistral
:author: Didier Guillevic
:date: 2025-01-12
"""
import gradio as gr
import icij_utils
import sqlalchemy
import smolagents
import os
import pathlib
#
# Init a SQLite database with the data files from ICIJ.org
#
ICIJ_LEAKS_DB_NAME = 'icij_leaks.db'
ICIJ_LEAKS_DATA_DIR = './icij_data'

# Remove any existing database file: it is rebuilt from the data files below
# on every start-up.
icij_db_path = pathlib.Path(ICIJ_LEAKS_DB_NAME)
icij_db_path.unlink(missing_ok=True)

# Load the ICIJ data files into the SQLite database.
loader = icij_utils.ICIJDataLoader(ICIJ_LEAKS_DB_NAME)
loader.load_all_files(ICIJ_LEAKS_DATA_DIR)

#
# Init a database connector (wrapping an SQLAlchemy engine) over the SQLite database
#
db = icij_utils.ICIJDatabaseConnector(ICIJ_LEAKS_DB_NAME)
#
# Build an SQL tool
#
# The tool description lists every documented table together with (when the
# table exists in the database schema) its columns, their types, and their
# documentation, so the agent knows what it can query.
schema = db.get_full_database_schema()
metadata = icij_utils.ICIJDatabaseMetadata()

_description_parts = [
    "Tool for querying the ICIJ offshore database containing financial data leaks. "
    "This tool can execute SQL queries and return the results. "
    "Beware that this tool's output is a string representation of the execution output.\n"
    "It can use the following tables:"
]

# Add table documentation
for table, doc in metadata.TABLE_DOCS.items():
    _description_parts.append(f"\n\nTable: {table}\n")
    _description_parts.append(f"Description: {doc.strip()}\n")
    _description_parts.append("Columns:\n")
    # Add column documentation and types; columns without an entry in
    # COLUMN_DOCS get a placeholder.
    if table in schema:
        table_column_docs = metadata.COLUMN_DOCS.get(table, {})
        for col_name, col_type in schema[table].items():
            col_doc = table_column_docs.get(col_name, "No documentation available")
            _description_parts.append(f" - {col_name}: {col_type}: {col_doc}\n")

# NOTE(review): source-ID documentation (metadata.SOURCE_IDS) is not included
# in the tool description.
tool_description = "".join(_description_parts)
@smolagents.tool
def sql_tool(query: str) -> str:
    """Execute an SQL query on the ICIJ database (full description is set below).

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE(review): the query text is produced by the language model and is
    # executed verbatim; no commit is issued on this connection.
    with db.get_engine().connect() as con:
        rows = con.execute(sqlalchemy.text(query))
        # One line per result row, each preceded by a newline; an empty string
        # when there are no rows. Rows are consumed while the connection is open.
        return "".join("\n" + str(row) for row in rows)


# Replace the placeholder docstring description with the generated one that
# documents the tables and columns.
sql_tool.description = tool_description
#
# language models
#
# NOTE(review): default_model is not used by the agent defined below, which
# runs on the Mistral model.
default_model = smolagents.HfApiModel()

# Fail fast with an actionable message rather than a bare KeyError.
try:
    mistral_api_key = os.environ["MISTRAL_API_KEY"]
except KeyError:
    raise RuntimeError(
        "The MISTRAL_API_KEY environment variable must be set to use the Mistral model."
    ) from None

mistral_model_id = "mistral/codestral-latest"
mistral_model = smolagents.LiteLLMModel(
    model_id=mistral_model_id,
    api_key=mistral_api_key,
    temperature=0.0
)
#
# Define the agent
#
# A code-writing agent whose single tool is the SQL tool, backed by the
# Mistral model configured above.
agent = smolagents.CodeAgent(model=mistral_model, tools=[sql_tool])
#
# Handler to extract the response's content
#
from typing import Union, Any
from dataclasses import is_dataclass
import json
class ResponseHandler:
    """Turn the various shapes of agent output into plain text for the UI."""

    @staticmethod
    def _as_text(value: Any) -> str:
        """Return *value* as text: unchanged if already a str, '' for None,
        otherwise its ``str()`` representation."""
        if isinstance(value, str):
            return value
        if value is None:
            return ''
        return str(value)

    @staticmethod
    def extract_content(response: Any) -> str:
        """
        Extract content from various types of agent responses.

        Args:
            response: The response from the agent. Shapes handled, in order:
                - a string: if it is a JSON-encoded object with a 'content'
                  key, that content is returned; otherwise the string itself;
                - an object whose 'content' attribute is a string
                  (e.g. a Message object);
                - a dict with a 'content' key;
                - a dataclass instance with a 'content' attribute.

        Returns:
            str: The extracted content (non-string content is converted with
            ``str()``, None becomes ''), or ``str(response)`` when none of the
            shapes above match.
        """
        if isinstance(response, str):
            # The JSON check has to happen before returning the raw string,
            # otherwise it can never be reached.
            try:
                parsed = json.loads(response)
            except json.JSONDecodeError:
                return response
            if isinstance(parsed, dict) and 'content' in parsed:
                return ResponseHandler._as_text(parsed['content'])
            return response

        # Message-like object exposing a string 'content' attribute.
        content = getattr(response, 'content', None)
        if isinstance(content, str):
            return content

        # Dictionary (e.g. from json.loads()).
        if isinstance(response, dict) and 'content' in response:
            return ResponseHandler._as_text(response['content'])

        # Dataclass *instance*: is_dataclass() is also true for dataclass types.
        if (
            is_dataclass(response)
            and not isinstance(response, type)
            and hasattr(response, 'content')
        ):
            return ResponseHandler._as_text(response.content)

        # Unknown shape: fall back to the string representation.
        return str(response)
handler = ResponseHandler()


def generate_response(query: str) -> str:
    """Generate a response given query.

    Args:
        - query: the question from the user
    Returns:
        - the response from the agent having access to a database over the ICIJ
          data and a large language model. A blank question is answered with a
          prompt to enter one, without calling the agent.
    """
    # Don't spend LLM calls (and agent steps) on an empty question.
    if not query or not query.strip():
        return "Please enter a question."
    agent_output = agent.run(query)
    # At times, the response appears to be a class instance with a 'content'
    # part. Hence, we pass the agent's response to a handler that extracts the
    # response's content as text.
    return handler.extract_content(agent_output)
#
# User interface
#
with gr.Blocks() as demo:
    gr.Markdown("""
    # SQL agent
    Database: ICIJ data on offshore financial data leaks. Very early "fast" prototyping.
    """)

    # Inputs: question
    question = gr.Textbox(
        label="Question to answer",
        placeholder=""
    )

    # Response
    response = gr.Textbox(
        label="Response",
        placeholder=""
    )

    # Buttons
    with gr.Row():
        response_button = gr.Button("Submit", variant='primary')
        clear_button = gr.Button("Clear", variant='secondary')

    # Sample questions about the ICIJ database
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            [
                [
                    (
                        "Can you list the entities with an address in Canada? "
                        "Please give the name of the entity and its address."
                    ),
                ],
                [
                    "Are there any entities located in Montreal, Canada?",
                ],
            ],
            inputs=[question,],
            outputs=[response,],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions"
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown("""
        - Agentic framework: smolagents
        - Data: icij.org
        - Database: SQLite, SQLAlchemy
        - Generation: Mistral
        - Examples: Generated using Claude.ai
        """)

    # Click actions
    response_button.click(
        fn=generate_response,
        inputs=[question,],
        outputs=[response,]
    )
    # Reset both the question and the response boxes.
    clear_button.click(
        fn=lambda: ('', ''),
        inputs=[],
        outputs=[question, response]
    )

demo.launch(show_api=False)