Spaces:
Sleeping
Sleeping
"""app.py | |
Smolagents agent given an SQL tool over a SQLite database built with data files
from the International Consortium of Investigative Journalists (ICIJ.org).
Agentic framework: | |
- smolagents | |
Database: | |
- SQLite | |
Generation: | |
- Mistral | |
:author: Didier Guillevic | |
:date: 2025-01-12 | |
""" | |
import gradio as gr | |
import icij_utils | |
import sqlalchemy | |
import smolagents | |
import os | |
import pathlib | |
#
# Init a SQLite database with the data files from ICIJ.org
#
ICIJ_LEAKS_DB_NAME = 'icij_leaks.db'
ICIJ_LEAKS_DATA_DIR = './icij_data'

# Remove any existing database file: it is recreated from the data files
# just below, so a leftover file from a previous run must not be reused.
icij_db_path = pathlib.Path(ICIJ_LEAKS_DB_NAME)
icij_db_path.unlink(missing_ok=True)

# Load the ICIJ data files into the (fresh) SQLite database
loader = icij_utils.ICIJDataLoader(ICIJ_LEAKS_DB_NAME)
loader.load_all_files(ICIJ_LEAKS_DATA_DIR)

#
# Init an SQLAlchemy connector (over the SQLite database)
#
# The database schema is not fetched here: the tool-building section reads it
# itself right before using it, so fetching it here as well was a dead store.
db = icij_utils.ICIJDatabaseConnector(ICIJ_LEAKS_DB_NAME)
#
# Build an SQL tool
#
schema = db.get_full_database_schema()
metadata = icij_utils.ICIJDatabaseMetadata()

# The description is collected as fragments and joined once at the end.
description_parts = [
    "Tool for querying the ICIJ offshore database containing financial data leaks. "
    "This tool can execute SQL queries and return the results. "
    "Beware that this tool's output is a string representation of the execution output.\n"
    "It can use the following tables:"
]

for table_name, table_doc in metadata.TABLE_DOCS.items():
    # Table header: name, description, then the start of the column list
    description_parts.append(f"\n\nTable: {table_name}\n")
    description_parts.append(f"Description: {table_doc.strip()}\n")
    description_parts.append("Columns:\n")

    # A documented table that is absent from the schema gets no column lines
    if table_name not in schema:
        continue

    # One line per column: name, type, and its documentation (when available)
    for col_name, col_type in schema[table_name].items():
        col_doc = metadata.COLUMN_DOCS.get(table_name, {}).get(
            col_name, "No documentation available"
        )
        description_parts.append(f" - {col_name}: {col_type}: {col_doc}\n")

tool_description = "".join(description_parts)

# NOTE: the per-source documentation (metadata.SOURCE_IDS) is currently not
# appended to the tool description.
# The decorator turns the function into a smolagents Tool, which is what
# CodeAgent(tools=[...]) expects; a plain function is not a Tool instance.
@smolagents.tool
def sql_tool(query: str) -> str:
    """Execute an SQL query on the ICIJ database (full description is set below).

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE(review): the query text is written by the language model and is
    # executed verbatim; nothing in this function restricts it to read-only
    # statements.
    with db.get_engine().connect() as con:
        rows = con.execute(sqlalchemy.text(query))
        # Each row is rendered with str() on its own line; the leading newline
        # per row is kept so the output format is unchanged. An empty result
        # set yields an empty string.
        return "".join("\n" + str(row) for row in rows)


# Replace the docstring-derived description with the one that documents the
# tables and columns of the database.
sql_tool.description = tool_description
#
# Language models
#
# Default Hugging Face model: instantiated here, although the agent defined
# below is given the Mistral model.
default_model = smolagents.HfApiModel()

# Mistral model, accessed through LiteLLM. The API key is read from the
# environment (a missing MISTRAL_API_KEY raises KeyError at start-up).
mistral_api_key = os.environ["MISTRAL_API_KEY"]
mistral_model_id = "mistral/codestral-latest"
mistral_model = smolagents.LiteLLMModel(
    api_key=mistral_api_key,
    model_id=mistral_model_id,
    temperature=0.0,
)

#
# Define the agent
#
# Code agent whose only tool is the SQL tool over the ICIJ database
agent = smolagents.CodeAgent(model=mistral_model, tools=[sql_tool])
# | |
# Handler to extract the response's content | |
# | |
from typing import Union, Any | |
from dataclasses import is_dataclass | |
import json | |
class ResponseHandler:
    """Turn the various shapes of an agent response into its content."""

    @staticmethod
    def extract_content(response: Any) -> str:
        """Extract content from various types of agent responses.

        Declared as a static method so that it can be called both on the class
        and on an instance (``handler.extract_content(x)``).

        Args:
            response: The response from the agent; could be a string (plain
                text or a JSON document), a Message-like object, a dict, or a
                dataclass.

        Returns:
            The extracted content. Anything that is not recognised is returned
            as its ``str()`` representation.
        """
        if isinstance(response, str):
            # A string may be a JSON document wrapping the content; this has to
            # be tried before returning the string as is, otherwise the JSON
            # case can never be reached.
            try:
                parsed = json.loads(response)
            except json.JSONDecodeError:
                return response
            if isinstance(parsed, dict) and 'content' in parsed:
                return parsed['content']
            # Valid JSON without a 'content' key: keep the original text
            return response

        # Message-like object with a textual 'content' attribute
        if hasattr(response, 'content') and isinstance(response.content, str):
            return response.content

        # Dictionary (e.g., from json.loads())
        if isinstance(response, dict) and 'content' in response:
            return response['content']

        # Dataclass whose 'content' is not a string (string content was
        # already handled by the attribute check above)
        if is_dataclass(response) and hasattr(response, 'content'):
            return response.content

        # Unknown type: fall back to the string representation
        return str(response)


handler = ResponseHandler()
def generate_response(query: str) -> str:
    """Run the agent on the user's question and return the answer's content.

    Args:
        query: the question from the user

    Returns:
        The response from the agent, which has access to a database over the
        ICIJ data and a large language model.
    """
    # The agent does not always hand back a bare string: at times the result
    # is a class instance carrying a 'content' part. The raw output therefore
    # always goes through the handler, which extracts the content.
    return handler.extract_content(agent.run(query))
#
# User interface
#
# Text shown at the top of the page
HEADER_MARKDOWN = """
# SQL agent
Database: ICIJ data on offshore financial data leaks. Very early "fast" prorotyping.
"""

# Text shown in the "Documentation" accordion
DOCUMENTATION_MARKDOWN = """
- Agentic framework: smolagents
- Data: icij.org
- Database: SQLite, SQLAlchemy
- Generation: Mistral
- Examples: Generated using Claude.ai
"""

# Sample questions: one row per example, one value per input component
SAMPLE_QUESTIONS = [
    [
        "Can you list the entities with an address in Canada? "
        "Please give the name of the entity an its address."
    ],
    ["Are there any entities located on Montreal, Canada?"],
]

with gr.Blocks() as demo:
    gr.Markdown(HEADER_MARKDOWN)

    # Input: the question typed by the user
    question = gr.Textbox(label="Question to answer", placeholder="")

    # Output: the response of the agent
    response = gr.Textbox(label="Response", placeholder="")

    # Buttons, side by side
    with gr.Row():
        response_button = gr.Button("Submit", variant='primary')
        clear_button = gr.Button("Clear", variant='secondary')

    # Example questions (not cached, so selecting one does not pre-compute
    # an answer)
    with gr.Accordion("Sample questions", open=False):
        gr.Examples(
            examples=SAMPLE_QUESTIONS,
            inputs=[question],
            outputs=[response],
            fn=generate_response,
            cache_examples=False,
            label="Sample questions",
        )

    # Documentation
    with gr.Accordion("Documentation", open=False):
        gr.Markdown(DOCUMENTATION_MARKDOWN)

    # Click actions: "Submit" runs the agent on the question, "Clear" empties
    # both the question and the response boxes.
    response_button.click(
        fn=generate_response, inputs=[question], outputs=[response]
    )
    clear_button.click(
        fn=lambda: ('', ''), inputs=[], outputs=[question, response]
    )

demo.launch(show_api=False)