|
import glob |
|
import os |
|
import sqlite3 |
|
import time |
|
from abc import ABC, abstractmethod |
|
from functools import lru_cache |
|
from typing import Any, List, Optional |
|
|
|
import requests |
|
from huggingface_hub import snapshot_download |
|
from requests.exceptions import ConnectionError, ReadTimeout |
|
|
|
from .logging_utils import get_logger |
|
from .types import SQLDatabase |
|
|
|
logger = get_logger() |
|
|
|
|
|
class DatabaseConnector(ABC):
    """Common interface implemented by every database connector in this module.

    Subclasses must provide schema retrieval and query execution.
    """

    def __init__(self, db_config: SQLDatabase):
        """Keep the configuration and make sure the local databases folder exists.

        The folder is ``<cache root>/databases``; the cache root is read from
        the ``UNITXT_TEXT2SQL_CACHE`` environment variable and falls back to
        ``cache/text2sql``.
        """
        cache_root = os.environ.get("UNITXT_TEXT2SQL_CACHE", "cache/text2sql")
        self.db_config = db_config
        self.databases_folder = os.path.join(cache_root, "databases")
        os.makedirs(self.databases_folder, exist_ok=True)

    @abstractmethod
    def get_table_schema(
        self,
    ) -> str:
        """Return a textual description of the database schema."""

    @abstractmethod
    def execute_query(self, query: str) -> Any:
        """Run ``query`` against the database and return its result."""
|
|
|
|
|
@lru_cache(maxsize=128)
def execute_query_local(db_path: str, query: str) -> Any:
    """Run ``query`` on the SQLite file at ``db_path`` and return all rows.

    Results are memoized per ``(db_path, query)`` pair by ``lru_cache``.
    When SQLite reports an error, it is logged and ``None`` is returned.
    """
    connection = None
    try:
        connection = sqlite3.connect(db_path)
        # ``Cursor.execute`` returns the cursor itself, so the calls chain.
        return connection.cursor().execute(query).fetchall()
    except sqlite3.Error as e:
        logger.info(f"Error executing SQL: {e}")
        return None
    finally:
        # ``connection`` stays None when ``sqlite3.connect`` itself failed.
        if connection is not None:
            connection.close()
|
|
|
|
|
class LocalSQLiteConnector(DatabaseConnector):
    """Database connector for SQLite database files stored on local disk.

    The database file is located (and, if needed, downloaded) under
    ``self.databases_folder`` when the connector is constructed.
    """

    def __init__(self, db_config: SQLDatabase):
        """Resolve the database file for ``db_config["db_id"]`` and open it.

        Raises:
            ValueError: if ``db_id`` is missing or empty.
            NotImplementedError: if ``db_id`` is not a supported database.
            FileNotFoundError: if no matching ``.sqlite`` file exists.
            FileExistsError: if several ``.sqlite`` files match.
        """
        super().__init__(db_config)
        db_id = self.db_config.get("db_id")
        if not db_id:
            raise ValueError("db_id is required for SQLiteConnector.")
        self.db_path = self.get_db_file_path(db_id)
        # This connection is only used for schema extraction; queries go
        # through the cached module-level ``execute_query_local``.
        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
        self.cursor: sqlite3.Cursor = self.conn.cursor()

    def download_database(self, db_id):
        """Download the database from huggingface if needed.

        Only ids containing ``"bird/"`` are supported. A ``download_done``
        marker file in the databases folder records a finished download so
        later calls skip it.

        Raises:
            NotImplementedError: if ``db_id`` does not contain ``"bird/"``.
        """
        if "bird/" not in db_id:
            raise NotImplementedError(
                f"current local db: {db_id} is not supported, only bird"
            )

        done_file_path = os.path.join(self.databases_folder, "download_done")
        if os.path.exists(done_file_path):
            return

        snapshot_download(
            repo_id="premai-io/birdbench",
            repo_type="dataset",
            local_dir=self.databases_folder,
            force_download=False,
            allow_patterns="*validation*",
        )
        # Create the (empty) marker only after the download returned.
        with open(done_file_path, "w"):
            pass

    def get_db_file_path(self, db_id):
        """Return the local path of the ``.sqlite`` file for ``db_id``.

        The file is searched recursively under the databases folder by the
        last ``/``-separated component of ``db_id``.

        Raises:
            FileNotFoundError: if no file matches.
            FileExistsError: if more than one file matches.
        """
        self.download_database(db_id)
        db_id = db_id.split("/")[-1]

        # Escape the literal path parts so glob metacharacters in the folder
        # or database name are not interpreted as wildcards.
        db_file_pattern = os.path.join(
            glob.escape(self.databases_folder), "**", glob.escape(db_id) + ".sqlite"
        )
        db_file_paths = glob.glob(db_file_pattern, recursive=True)

        if not db_file_paths:
            raise FileNotFoundError(f"Database file {db_id} not found.")
        if len(db_file_paths) > 1:
            raise FileExistsError(f"More than one files matched for {db_id}")
        return db_file_paths[0]

    def get_table_schema(
        self,
    ) -> str:
        """Return the ``CREATE TABLE`` statements of all tables.

        Statements are joined by blank lines; the ``sqlite_sequence``
        bookkeeping table is skipped.
        """
        # Fetch name and DDL in one pass instead of interpolating each table
        # name into a second query, which broke on names containing quotes.
        self.cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table'")
        statements = [
            sql for name, sql in self.cursor.fetchall() if name != "sqlite_sequence"
        ]
        return "\n\n".join(statements)

    def execute_query(self, query: str) -> Any:
        """Execute ``query`` against the SQLite file via ``execute_query_local``."""
        return execute_query_local(self.db_path, query)
|
|
|
|
|
class InMemoryDatabaseConnector(DatabaseConnector):
    """Database connector for mocking databases with in-memory data structures.

    ``db_config["data"]`` maps each table name to a mapping with a
    ``"columns"`` list and a ``"rows"`` list.
    """

    def __init__(self, db_config: SQLDatabase):
        """Store the table definitions found under ``db_config["data"]``.

        Raises:
            ValueError: if ``data`` is missing or empty.
        """
        super().__init__(db_config)
        self.tables = db_config.get("data", None)

        if not self.tables:
            raise ValueError("data is required for InMemoryDatabaseConnector.")

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier."""
        return '"' + str(name).replace('"', '""') + '"'

    def get_table_schema(
        self,
        select_tables: Optional[List[str]] = None,
    ) -> str:
        """Generate mock ``CREATE TABLE`` statements from the tables structure.

        Every column is declared as ``TEXT``.

        Args:
            select_tables: optional table names to restrict the schema to.
                Matching is case-insensitive; ``None`` or an empty list
                selects every table.
        """
        # Lower-case both sides: previously only the table name was
        # lower-cased, so a mixed-case entry could never match.
        wanted = {name.lower() for name in select_tables} if select_tables else None

        statements = []
        for table_name, table_data in self.tables.items():
            if wanted is not None and table_name.lower() not in wanted:
                continue
            columns = ", ".join(f"`{col}` TEXT" for col in table_data["columns"])
            statements.append(f"CREATE TABLE `{table_name}` ({columns});")

        return "\n\n".join(statements)

    def execute_query(self, query: str) -> Any:
        """Run ``query`` on a fresh in-memory SQLite database built from the tables.

        Returns the fetched rows, or ``None`` when executing ``query`` raises
        ``sqlite3.Error`` (the error is logged). Errors raised while creating
        or filling the tables propagate to the caller.
        """
        conn = sqlite3.connect(":memory:")
        # The outer try/finally closes the connection even when table setup
        # fails, not only when the user query fails.
        try:
            cursor = conn.cursor()
            logger.debug("Running SQL query over in-memory DB")

            for table_name, table_data in self.tables.items():
                columns = table_data["columns"]
                rows = table_data["rows"]

                # Quote identifiers so names with spaces or reserved words
                # work, in line with the quoting used by get_table_schema.
                quoted_table = self._quote_identifier(table_name)
                quoted_columns = ", ".join(
                    self._quote_identifier(col) for col in columns
                )
                cursor.execute(f"CREATE TABLE {quoted_table} ({quoted_columns})")

                placeholders = ", ".join(["?"] * len(columns))
                cursor.executemany(
                    f"INSERT INTO {quoted_table} VALUES ({placeholders})", rows
                )

            try:
                cursor.execute(query)
                return cursor.fetchall()
            except sqlite3.Error as e:
                logger.info(f"Error executing SQL: {e}")
                return None
        finally:
            conn.close()
|
|
|
|
|
@lru_cache(maxsize=128)
def execute_query_remote(
    api_url: str,
    database_id: str,
    api_key: str,
    query: str,
    retryable_exceptions: tuple = (ConnectionError, ReadTimeout),
    max_retries: int = 3,
    retry_delay: int = 5,
    timeout: int = 30,
) -> Optional[dict]:
    """Execute a query against the remote database, retrying certain failures.

    The query is POSTed as JSON to ``{api_url}/sql`` with a bearer token.
    Failures matching ``retryable_exceptions`` and HTTP responses with a
    status of 500 or above are retried; any other error aborts immediately.

    Args:
        api_url: base URL of the SQL service.
        database_id: sent as ``dataSourceId`` in the request body.
        api_key: bearer token for the ``Authorization`` header.
        query: SQL text to execute.
        retryable_exceptions: exception types that trigger a retry.
        max_retries: number of retries after the first attempt.
        retry_delay: seconds to sleep between attempts.
        timeout: per-request timeout in seconds.

    Returns:
        The decoded JSON response, or ``None`` on failure.

    NOTE(review): because of ``lru_cache``, a ``None`` returned after a
    failure is memoized too, so the same arguments will not be retried until
    the entry is evicted or the cache is cleared.
    """
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {"sql": query, "dataSourceId": database_id}

    retries = 0
    while retries <= max_retries:
        try:
            response = requests.post(
                f"{api_url}/sql",
                headers=headers,
                json=payload,
                verify=True,
                timeout=timeout,
            )
            response.raise_for_status()
            return response.json()

        except retryable_exceptions as e:
            failure = f"Attempt {retries + 1} failed with error: {e}."

        except requests.exceptions.HTTPError as e:
            # ``e.response`` may be None; treat that like a non-retryable error
            # instead of failing on the attribute access.
            status_code = e.response.status_code if e.response is not None else None
            if status_code is None or status_code < 500:
                logger.error(f"HTTP Error on attempt {retries}: {e}")
                return None
            failure = f"Server error, attempt {retries + 1} failed with error: {e}."

        except Exception as e:
            logger.error(f"Unexpected error on attempt {retries}: {e}")
            return None

        # Only retryable failures reach this point.
        retries += 1
        if retries > max_retries:
            # No retry follows, so do not announce one.
            logger.warning(failure)
            logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
            return None

        logger.warning(f"{failure} Retrying in {retry_delay} seconds.")
        time.sleep(retry_delay)

    return None
|
|
|
|
|
class RemoteDatabaseConnector(DatabaseConnector):
    """Database connector for remote databases accessed via HTTP.

    ``db_config["db_id"]`` is a comma-separated string: its first element is
    used as the API URL, and the text following ``db_id=`` (up to the next
    comma) as the remote database id. The API key is read from the
    ``SQL_API_KEY`` environment variable.
    """

    def __init__(self, db_config: SQLDatabase):
        """Parse the connection details and prepare the request headers.

        Raises:
            ValueError: if ``db_id`` is missing or empty, if the API URL or
                database id cannot be extracted from it, or if
                ``SQL_API_KEY`` is not set.
        """
        super().__init__(db_config)

        # Raise explicitly rather than assert: assertions are stripped under
        # ``python -O``, and the other connectors raise ValueError here too.
        db_id = db_config.get("db_id")
        if not db_id:
            raise ValueError("db_id must be in db_config for RemoteDatabaseConnector")

        self.api_url = db_id.split(",")[0]
        self.database_id = db_id.split("db_id=")[-1].split(",")[0]

        if not self.api_url or not self.database_id:
            raise ValueError(
                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
            )

        self.api_key = os.getenv("SQL_API_KEY", None)
        if not self.api_key:
            raise ValueError(
                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
            )

        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        # Timeout in seconds, used for both schema and query requests.
        self.timeout = 30

    def get_table_schema(
        self,
    ) -> str:
        """Retrieve the schema of the remote database as text.

        Sends a GET request to ``{api_url}/datasource/{database_id}`` and
        emits one line per table listing its column names.

        Raises:
            OSError: if the response status is not 200.
        """
        cur_api_url = f"{self.api_url}/datasource/{self.database_id}"
        response = requests.get(
            cur_api_url,
            headers=self.headers,
            verify=True,
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise OSError(f"Could not fetch schema from {cur_api_url}")

        schema = response.json()["schema"]
        return "".join(
            f"Table: {table['table_name']} has columns: {[col['column_name'] for col in table['columns']]}\n"
            for table in schema["tables"]
        )

    def execute_query(self, query: str) -> Any:
        """Execute ``query`` remotely via the cached ``execute_query_remote``."""
        return execute_query_remote(
            api_url=self.api_url,
            database_id=self.database_id,
            api_key=self.api_key,
            query=query,
            timeout=self.timeout,
        )
|
|
|
|
|
def get_db_connector(db_type: str):
    """Return the connector class registered for ``db_type``.

    Supported values are ``"local"``, ``"in_memory"`` and ``"remote"``. The
    class itself is returned; the caller instantiates it.

    Raises:
        ValueError: if ``db_type`` is not one of the supported values.
    """
    if db_type == "local":
        return LocalSQLiteConnector
    if db_type == "in_memory":
        return InMemoryDatabaseConnector
    if db_type == "remote":
        return RemoteDatabaseConnector

    raise ValueError(f"Unsupported database type: {db_type}")
|
|