Spaces:
Runtime error
Runtime error
File size: 7,519 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
if TYPE_CHECKING:
from pyspark.sql import DataFrame, Row, SparkSession
class SparkSQL:
    """SparkSQL is a utility class for interacting with Spark SQL."""

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Initialize a SparkSQL object.

        Args:
            spark_session: A SparkSession object.
                If not provided, one will be created.
            catalog: The catalog to use.
                If not provided, the default catalog will be used.
            schema: The schema to use.
                If not provided, the default schema will be used.
            ignore_tables: A list of tables to ignore.
                If not provided, all tables will be used.
            include_tables: A list of tables to include.
                If not provided, all tables will be used.
            sample_rows_in_table_info: The number of rows to include in the
                table info. Defaults to 3.

        Raises:
            ImportError: If pyspark cannot be imported.
            ValueError: If ``include_tables`` or ``ignore_tables`` names a
                table that ``SHOW TABLES`` does not report.
            TypeError: If ``sample_rows_in_table_info`` is not an integer.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as exc:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from exc

        # Compare against None explicitly so that a caller-supplied session is
        # never replaced just because it happens to evaluate as falsy.
        self._spark = (
            spark_session
            if spark_session is not None
            else SparkSession.builder.getOrCreate()
        )
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)

        self._all_tables = set(self._get_all_table_names())

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Creating a remote Spark Session via Spark connect.

        For example: SparkSQL.from_uri("sc://localhost:15002")

        Args:
            database_uri: The remote URI handed to ``SparkSession.builder.remote``.
            engine_args: Accepted for signature compatibility; this method
                does not read it.
            **kwargs: Forwarded unchanged to the ``SparkSQL`` constructor.

        Returns:
            A ``SparkSQL`` instance bound to the remote session.

        Raises:
            ValueError: If pyspark cannot be imported.
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as exc:
            # The exception type stays ValueError (unlike __init__) so that
            # existing callers catching ValueError keep working.
            raise ValueError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from exc

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns the ``include_tables`` selection when one was given, otherwise
        every known table minus ``ignore_tables``. Both paths are sorted.
        """
        # sorting the result can help LLM understanding it.
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        """Return the ``tableName`` column of ``SHOW TABLES`` as a list."""
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return [row.tableName for row in rows]

    def _get_create_table_stmt(self, table: str) -> str:
        """Return the ``SHOW CREATE TABLE`` DDL for ``table``, trimmed at USING.

        Everything from the first ``USING`` onwards is dropped and a ``;`` is
        appended. When the statement has no ``USING`` it is kept whole.
        """
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        if using_clause_index == -1:
            # str.find signals "not found" with -1; slicing with it would
            # silently chop the final character off the statement.
            return statement + ";"
        return statement[:using_clause_index] + ";"

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Build a textual description of the requested tables.

        Each table contributes its trimmed CREATE statement and, when
        ``sample_rows_in_table_info`` is non-zero, a ``/* ... */`` block with
        sample rows. Table entries are separated by a blank line.

        Args:
            table_names: Tables to describe. If ``None``, all usable tables
                are described.

        Raises:
            ValueError: If ``table_names`` contains a table that is not usable.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names

        tables = []
        for table_name in all_table_names:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        return "\n\n".join(tables)

    def _get_sample_spark_rows(self, table: str) -> str:
        """Return a tab-separated header and sample rows for ``table``."""
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(field.name for field in df.schema.fields)
        try:
            sample_rows = self._get_dataframe_results(df)
            # save the sample rows in string format
            sample_rows_str = "\n".join("\t".join(row) for row in sample_rows)
        except Exception:
            # Best effort: if the rows cannot be fetched, still return the
            # header so the caller gets the column names.
            sample_rows_str = ""
        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        """Convert a row into a tuple holding ``str()`` of each value, in order."""
        # Iterate the row positionally rather than via asDict(): a dict keeps
        # only one entry per column name, so duplicate names would lose values.
        return tuple(map(str, row))

    def _get_dataframe_results(self, df: DataFrame) -> list:
        """Collect ``df`` and return its rows as a list of string tuples."""
        return [self._convert_row_as_tuple(row) for row in df.collect()]

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute ``command`` and return ``str()`` of the list of result rows.

        Args:
            command: The SQL text passed to ``spark.sql``.
            fetch: ``"one"`` limits the result to a single row; any other
                value returns every row.
        """
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info`, the specified number of sample rows will be
        appended to each table description. This can increase performance as
        demonstrated in the paper.

        A ``ValueError`` from ``get_table_info`` is returned as an
        ``"Error: ..."`` string instead of being raised.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement succeeds, the string returned by ``run`` is returned.
        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message
            return f"Error: {e}"
|