from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, List, Optional

if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row, SparkSession


class SparkSQL:
    """SparkSQL is a utility class for interacting with Spark SQL."""

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
| """Initialize a SparkSQL object. | |
| Args: | |
| spark_session: A SparkSession object. | |
| If not provided, one will be created. | |
| catalog: The catalog to use. | |
| If not provided, the default catalog will be used. | |
| schema: The schema to use. | |
| If not provided, the default schema will be used. | |
| ignore_tables: A list of tables to ignore. | |
| If not provided, all tables will be used. | |
| include_tables: A list of tables to include. | |
| If not provided, all tables will be used. | |
| sample_rows_in_table_info: The number of rows to include in the table info. | |
| Defaults to 3. | |
| """ | |
        try:
            from pyspark.sql import SparkSession
        except ImportError:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            )
        self._spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)
        self._all_tables = set(self._get_all_table_names())
        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables
        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")
        self._sample_rows_in_table_info = sample_rows_in_table_info
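
    # Usage sketch (the schema and table names are illustrative, not part of
    # this module): construct against the session's current catalog/schema and
    # restrict the visible tables.
    #
    #     spark_sql = SparkSQL(schema="default", ignore_tables=["staging_tmp"])
    #     print(spark_sql.get_usable_table_names())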

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Create a remote Spark session via Spark Connect.

        For example: SparkSQL.from_uri("sc://localhost:15002")
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            )
        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)
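
    # Sketch: connecting through Spark Connect (assumes a server is already
    # listening on localhost:15002 and that the installed pyspark build
    # supports Spark Connect):
    #
    #     spark_sql = SparkSQL.from_uri("sc://localhost:15002", schema="default")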

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        if self._include_tables:
            return self._include_tables
        # Sorting the result can help an LLM understand it.
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return list(map(lambda row: row.tableName, rows))

    def _get_create_table_stmt(self, table: str) -> str:
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        if using_clause_index == -1:
            # No USING clause (str.find returned -1); keep the full statement
            # rather than slicing off its last character.
            return statement + ";"
        return statement[:using_clause_index] + ";"
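
    # Illustrative (hypothetical DDL): "CREATE TABLE employees (id INT, name
    # STRING) USING parquet OPTIONS (...)" is trimmed to
    # "CREATE TABLE employees (id INT, name STRING) ;".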

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get the CREATE TABLE statement, plus optional sample rows, per table."""
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names
        tables = []
        for table_name in all_table_names:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        final_str = "\n\n".join(tables)
        return final_str
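
    # Rendered table info looks roughly like this (table and values are made
    # up for illustration; columns and row values are tab-separated):
    #
    #     CREATE TABLE employees (id INT, name STRING) ;
    #
    #     /*
    #     3 rows from employees table:
    #     id  name
    #     1   Alice
    #     2   Bob
    #     */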

    def _get_sample_spark_rows(self, table: str) -> str:
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(list(map(lambda f: f.name, df.schema.fields)))
        try:
            sample_rows = self._get_dataframe_results(df)
            # Save the sample rows in string format.
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
        except Exception:
            sample_rows_str = ""
        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        return tuple(map(str, row.asDict().values()))

    def _get_dataframe_results(self, df: DataFrame) -> list:
        return list(map(self._convert_row_as_tuple, df.collect()))

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return the results as a string.

        If fetch is "one", only the first row is returned.
        """
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))
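
    # Sketch (query and result are illustrative): run() returns a stringified
    # list of row tuples, with every value cast to str:
    #
    #     spark_sql.run("SELECT COUNT(*) FROM employees")  # -> "[('2',)]"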

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info` is set, the specified number of sample rows
        will be appended to each table description. This can increase performance
        as demonstrated in the paper.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message.
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.
        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message.
            return f"Error: {e}"