Spaces:
Runtime error
Runtime error
"""SQLAlchemy wrapper around a database.""" | |
from __future__ import annotations | |
import warnings | |
from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union | |
import sqlalchemy | |
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text | |
from sqlalchemy.engine import Engine | |
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError | |
from sqlalchemy.schema import CreateTable | |
from sqlalchemy.types import NullType | |
from langchain.utils import get_from_env | |
def _format_index(index: sqlalchemy.engine.interfaces.ReflectedIndex) -> str: | |
return ( | |
f'Name: {index["name"]}, Unique: {index["unique"]},' | |
f' Columns: {str(index["column_names"])}' | |
) | |
def truncate_word(content: Any, *, length: int, suffix: str = "...") -> str:
    """Truncate a string on a word boundary so it fits within ``length``.

    Args:
        content: Value to truncate. Anything that is not a ``str`` is returned
            unchanged.
        length: Maximum length of the returned string. A value ``<= 0``
            disables truncation.
        suffix: Marker appended to a truncated string.

    Returns:
        ``content`` itself when it is not a string, when ``length <= 0`` or
        when it already fits; otherwise a shortened string ending in
        ``suffix`` (or a plain hard cut when ``length`` is too small to hold
        ``suffix``).
    """
    if not isinstance(content, str) or length <= 0:
        return content

    if len(content) <= length:
        return content

    if length < len(suffix):
        # No room for the suffix: a negative slice bound would count from the
        # end of the string and produce a result longer than ``length``.
        return content[:length]

    # Cut to the available budget, then drop the trailing partial word.
    return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix
class SQLDatabase:
    """SQLAlchemy wrapper around a database.

    Wraps an ``Engine``, records which tables are usable (optionally filtered
    by include/ignore lists), reflects them into a ``MetaData`` object and
    offers helpers to describe the tables and to run SQL commands.
    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Wrap an existing engine and reflect its usable tables.

        Args:
            engine: SQLAlchemy engine to wrap.
            schema: Schema passed to the inspector and to reflection.
            metadata: ``MetaData`` to reflect into; a new one is created when
                omitted.
            ignore_tables: Table names to exclude. Mutually exclusive with
                ``include_tables``.
            include_tables: Table names to restrict to. Mutually exclusive
                with ``ignore_tables``.
            sample_rows_in_table_info: Number of sample rows appended to each
                table description (falsy disables them).
            indexes_in_table_info: Whether index listings are appended to
                each table description.
            custom_table_info: Mapping of table name to a description used
                instead of the generated one. Keys that are not tables of
                the database are dropped.
            view_support: Whether views are treated as tables.
            max_string_length: Maximum length of string values returned by
                ``run``.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or if either names a table not found in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an integer or
                ``custom_table_info`` is not a dictionary.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # Views are added to the table list only when view_support is True.
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # Only keep the entries whose table is also present in the database.
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length

        self._metadata = metadata or MetaData()
        # Views are reflected as well when view_support is True.
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SQLDatabase:
        """Create an instance from a database URI.

        Args:
            database_uri: URI handed to ``create_engine``.
            engine_args: Extra keyword arguments for ``create_engine``.
            **kwargs: Forwarded to the class constructor.

        Returns:
            A new instance wrapping the created engine.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @classmethod
    def from_databricks(
        cls,
        catalog: str,
        schema: str,
        host: Optional[str] = None,
        api_token: Optional[str] = None,
        warehouse_id: Optional[str] = None,
        cluster_id: Optional[str] = None,
        engine_args: Optional[dict] = None,
        **kwargs: Any,
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a Databricks connection.
        This method requires the 'databricks-sql-connector' package. If not installed,
        it can be added using `pip install databricks-sql-connector`.

        Args:
            catalog (str): The catalog name in the Databricks database.
            schema (str): The schema name in the catalog.
            host (Optional[str]): The Databricks workspace hostname, excluding
                'https://' part. If not provided, it attempts to fetch from the
                environment variable 'DATABRICKS_HOST'. If still unavailable and if
                running in a Databricks notebook, it defaults to the current workspace
                hostname. Defaults to None.
            api_token (Optional[str]): The Databricks personal access token for
                accessing the Databricks SQL warehouse or the cluster. If not provided,
                it attempts to fetch from 'DATABRICKS_TOKEN'. If still unavailable
                and running in a Databricks notebook, a temporary token for the current
                user is generated. Defaults to None.
            warehouse_id (Optional[str]): The warehouse ID in the Databricks SQL. If
                provided, the method configures the connection to use this warehouse.
                Cannot be used with 'cluster_id'. Defaults to None.
            cluster_id (Optional[str]): The cluster ID in the Databricks Runtime. If
                provided, the method configures the connection to use this cluster.
                Cannot be used with 'warehouse_id'. If running in a Databricks notebook
                and both 'warehouse_id' and 'cluster_id' are None, it uses the ID of the
                cluster the notebook is attached to. Defaults to None.
            engine_args (Optional[dict]): The arguments to be used when connecting
                Databricks. Defaults to None.
            **kwargs (Any): Additional keyword arguments for the `from_uri` method.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
                Databricks connection details.

        Raises:
            ValueError: If 'databricks-sql-connector' is not found, or if both
                'warehouse_id' and 'cluster_id' are provided, or if neither
                'warehouse_id' nor 'cluster_id' are provided and it's not executing
                inside a Databricks notebook.
        """
        # Imported only to verify that the connector package is installed.
        try:
            from databricks import sql  # noqa: F401
        except ImportError:
            raise ValueError(
                "databricks-sql-connector package not found, please install with"
                " `pip install databricks-sql-connector`"
            )

        # The notebook context supplies the defaults below when it is importable.
        context = None
        try:
            from dbruntime.databricks_repl_context import get_context

            context = get_context()
        except ImportError:
            pass

        default_host = context.browserHostName if context else None
        if host is None:
            host = get_from_env("host", "DATABRICKS_HOST", default_host)

        default_api_token = context.apiToken if context else None
        if api_token is None:
            api_token = get_from_env("api_token", "DATABRICKS_TOKEN", default_api_token)

        if warehouse_id is None and cluster_id is None:
            if context:
                cluster_id = context.clusterId
            else:
                raise ValueError(
                    "Need to provide either 'warehouse_id' or 'cluster_id'."
                )

        if warehouse_id and cluster_id:
            raise ValueError("Can't have both 'warehouse_id' or 'cluster_id'.")

        if warehouse_id:
            http_path = f"/sql/1.0/warehouses/{warehouse_id}"
        else:
            http_path = f"/sql/protocolv1/o/0/{cluster_id}"

        uri = (
            f"databricks://token:{api_token}@{host}?"
            f"http_path={http_path}&catalog={catalog}&schema={schema}"
        )
        return cls.from_uri(database_uri=uri, engine_args=engine_args, **kwargs)

    @classmethod
    def from_cnosdb(
        cls,
        url: str = "127.0.0.1:8902",
        user: str = "root",
        password: str = "",
        tenant: str = "cnosdb",
        database: str = "public",
    ) -> SQLDatabase:
        """
        Class method to create an SQLDatabase instance from a CnosDB connection.
        This method requires the 'cnos-connector' package. If not installed, it
        can be added using `pip install cnos-connector`.

        Args:
            url (str): The HTTP connection host name and port number of the CnosDB
                service, excluding "http://" or "https://", with a default value
                of "127.0.0.1:8902".
            user (str): The username used to connect to the CnosDB service, with a
                default value of "root".
            password (str): The password of the user connecting to the CnosDB service,
                with a default value of "".
            tenant (str): The name of the tenant used to connect to the CnosDB service,
                with a default value of "cnosdb".
            database (str): The name of the database in the CnosDB tenant.

        Returns:
            SQLDatabase: An instance of SQLDatabase configured with the provided
            CnosDB connection details.

        Raises:
            ValueError: If the 'cnos-connector' package is not installed.
        """
        # Keep the try body to the import alone so that an ImportError raised
        # while building the engine is not misreported as a missing connector.
        try:
            from cnosdb_connector import make_cnosdb_langchain_uri
        except ImportError:
            raise ValueError(
                "cnos-connector package not found, please install with"
                " `pip install cnos-connector`"
            )

        uri = make_cnosdb_langchain_uri(url, user, password, tenant, database)
        return cls.from_uri(database_uri=uri)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Return the sorted names of the usable tables.

        When ``include_tables`` was given, exactly those names are returned;
        otherwise every known table minus ``ignore_tables``.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def get_table_names(self) -> Iterable[str]:
        """Deprecated alias of ``get_usable_table_names``; emits a warning."""
        warnings.warn(
            "This method is deprecated - please use `get_usable_table_names`."
        )
        return self.get_usable_table_names()

    @property
    def table_info(self) -> str:
        """Information about all tables in the database."""
        return self.get_table_info()

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        Args:
            table_names: Tables to describe; all usable tables when ``None``.

        Returns:
            The table descriptions, sorted and separated by blank lines.

        Raises:
            ValueError: If a requested table is not among the usable tables.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        # Build the lookup set and the dialect check once, not once per table.
        wanted = set(all_table_names)
        is_sqlite = self.dialect == "sqlite"
        meta_tables = [
            tbl
            for tbl in self._metadata.sorted_tables
            if tbl.name in wanted
            and not (is_sqlite and tbl.name.startswith("sqlite_"))
        ]

        tables = []
        for table in meta_tables:
            if self._custom_table_info and table.name in self._custom_table_info:
                tables.append(self._custom_table_info[table.name])
                continue

            # Drop columns whose type is NullType (the original code labels
            # these as JSON-typed columns). Iterate over a snapshot because the
            # collection is modified inside the loop.
            # NOTE(review): this mutates the reflected Table held in
            # self._metadata, so the columns stay removed for later calls.
            for _, column in list(table.columns.items()):
                if type(column.type) is NullType:
                    table._columns.remove(column)

            # The CREATE TABLE statement is the core of the description.
            create_table = str(CreateTable(table).compile(self._engine))
            table_info = f"{create_table.rstrip()}"
            has_extra_info = (
                self._indexes_in_table_info or self._sample_rows_in_table_info
            )
            # Indexes and sample rows are wrapped in one SQL block comment.
            if has_extra_info:
                table_info += "\n\n/*"
            if self._indexes_in_table_info:
                table_info += f"\n{self._get_table_indexes(table)}\n"
            if self._sample_rows_in_table_info:
                table_info += f"\n{self._get_sample_rows(table)}\n"
            if has_extra_info:
                table_info += "*/"
            tables.append(table_info)
        tables.sort()
        final_str = "\n\n".join(tables)
        return final_str

    def _get_table_indexes(self, table: Table) -> str:
        """Return a formatted listing of the indexes defined on ``table``."""
        # Use the configured schema, matching how table names were listed in
        # __init__; without it the inspector looks in the default schema.
        indexes = self._inspector.get_indexes(table.name, schema=self._schema)
        indexes_formatted = "\n".join(map(_format_index, indexes))
        return f"Table Indexes:\n{indexes_formatted}"

    def _get_sample_rows(self, table: Table) -> str:
        """Return a tab-separated preview of the first rows of ``table``.

        The header line states the configured sample size, followed by the
        column names and then one line per row. Each value is converted with
        ``str`` and cut to 100 characters.
        """
        # Build the select command.
        command = select(table).limit(self._sample_rows_in_table_info)

        # Column names in string format.
        columns_str = "\t".join([col.name for col in table.columns])

        try:
            # Fetch the sample rows while the connection is still open.
            with self._engine.connect() as connection:
                sample_rows_result = connection.execute(command)  # type: ignore
                # Shorten the values in the sample rows.
                sample_rows = [
                    [str(value)[:100] for value in row] for row in sample_rows_result
                ]

            # Sample rows in string format.
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])

        # In some dialects a 'ProgrammingError' is returned when there are no
        # rows in the table.
        except ProgrammingError:
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _execute(
        self,
        command: str,
        fetch: Union[Literal["all"], Literal["one"]] = "all",
    ) -> Sequence[Dict[str, Any]]:
        """
        Executes SQL command through underlying engine.

        If the statement returns no rows, an empty list is returned.

        Args:
            command: SQL text to execute.
            fetch: ``"all"`` to return every row, ``"one"`` for the first only.

        Returns:
            A list with one dictionary per returned row.

        Raises:
            ValueError: If the statement returns rows and ``fetch`` is neither
                ``"one"`` nor ``"all"``.
        """
        with self._engine.begin() as connection:
            if self._schema is not None:
                # Point the session at the configured schema, using the
                # statement each dialect expects.
                # NOTE(review): the duckdb and oracle branches interpolate the
                # schema name into the SQL text; it must be trusted input.
                dialect = self.dialect
                if dialect == "snowflake":
                    connection.exec_driver_sql(
                        "ALTER SESSION SET search_path = %s", (self._schema,)
                    )
                elif dialect == "bigquery":
                    connection.exec_driver_sql("SET @@dataset_id=?", (self._schema,))
                elif dialect == "mssql":
                    pass
                elif dialect == "trino":
                    connection.exec_driver_sql("USE ?", (self._schema,))
                elif dialect == "duckdb":
                    # Unclear which parameterized argument syntax duckdb supports.
                    # The docs for the duckdb client say they support multiple,
                    # but `duckdb_engine` seemed to struggle with all of them:
                    # https://github.com/Mause/duckdb_engine/issues/796
                    connection.exec_driver_sql(f"SET search_path TO {self._schema}")
                elif dialect == "oracle":
                    connection.exec_driver_sql(
                        f"ALTER SESSION SET CURRENT_SCHEMA = {self._schema}"
                    )
                else:  # postgresql and other compatible dialects
                    connection.exec_driver_sql("SET search_path TO %s", (self._schema,))
            cursor = connection.execute(text(command))
            if cursor.returns_rows:
                if fetch == "all":
                    result = [x._asdict() for x in cursor.fetchall()]
                elif fetch == "one":
                    first_result = cursor.fetchone()
                    result = [] if first_result is None else [first_result._asdict()]
                else:
                    raise ValueError("Fetch parameter must be either 'one' or 'all'")
                return result
        return []

    def run(
        self,
        command: str,
        fetch: Union[Literal["all"], Literal["one"]] = "all",
    ) -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.
        """
        result = self._execute(command, fetch)
        # Each row becomes a tuple of its values; string values longer than
        # the configured maximum are truncated.
        res = [
            tuple(truncate_word(c, length=self._max_string_length) for c in r.values())
            for r in result
        ]
        if not res:
            return ""
        return str(res)

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` raised while building the description is returned as
        an ``"Error: ..."`` string instead of propagating.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message.
            return f"Error: {e}"

    def run_no_throw(
        self,
        command: str,
        fetch: Union[Literal["all"], Literal["one"]] = "all",
    ) -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(command, fetch)
        except SQLAlchemyError as e:
            # Format the error message.
            return f"Error: {e}"