from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterable, List, Optional

if TYPE_CHECKING:
    from pyspark.sql import DataFrame, Row, SparkSession


class SparkSQL:
    """SparkSQL is a utility class for interacting with Spark SQL."""

    def __init__(
        self,
        spark_session: Optional[SparkSession] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
    ):
        """Initialize a SparkSQL object.

        Args:
            spark_session: A SparkSession object.
              If not provided, one will be created.
            catalog: The catalog to use.
              If not provided, the default catalog will be used.
            schema: The schema to use.
              If not provided, the default schema will be used.
            ignore_tables: A list of tables to ignore.
              If not provided, no tables are ignored.
            include_tables: A list of tables to include.
              If not provided, all tables are included.
            sample_rows_in_table_info: The number of rows to include in the table info.
              Defaults to 3.
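
        Example (a minimal sketch; "titanic" is a hypothetical table name
        registered in the current schema):

            spark_sql = SparkSQL(
                schema="default",
                include_tables=["titanic"],
                sample_rows_in_table_info=3,
            )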
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError as e:
            raise ImportError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            ) from e

        self._spark = (
            spark_session if spark_session else SparkSession.builder.getOrCreate()
        )
        if catalog is not None:
            self._spark.catalog.setCurrentCatalog(catalog)
        if schema is not None:
            self._spark.catalog.setCurrentDatabase(schema)

        self._all_tables = set(self._get_all_table_names())
        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> SparkSQL:
        """Creating a remote Spark Session via Spark connect.
        For example: SparkSQL.from_uri("sc://localhost:15002")
        """
        try:
            from pyspark.sql import SparkSession
        except ImportError:
            raise ValueError(
                "pyspark is not installed. Please install it with `pip install pyspark`"
            )

        spark = SparkSession.builder.remote(database_uri).getOrCreate()
        return cls(spark, **kwargs)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available."""
        if self._include_tables:
            return self._include_tables
        # Sorting the result can help the LLM understand it.
        return sorted(self._all_tables - self._ignore_tables)

    def _get_all_table_names(self) -> Iterable[str]:
        rows = self._spark.sql("SHOW TABLES").select("tableName").collect()
        return [row.tableName for row in rows]

    def _get_create_table_stmt(self, table: str) -> str:
        statement = (
            self._spark.sql(f"SHOW CREATE TABLE {table}").collect()[0].createtab_stmt
        )
        # Ignore the data source provider and options to reduce the number of tokens.
        using_clause_index = statement.find("USING")
        if using_clause_index == -1:
            # No USING clause found; keep the full statement.
            return statement + ";"
        return statement[:using_clause_index] + ";"

    def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
        """Get CREATE TABLE statements (and sample rows) for the given tables.

        If table_names is None, all usable tables are described.
        """
        all_table_names = self.get_usable_table_names()
        if table_names is not None:
            missing_tables = set(table_names).difference(all_table_names)
            if missing_tables:
                raise ValueError(f"table_names {missing_tables} not found in database")
            all_table_names = table_names
        tables = []
        for table_name in all_table_names:
            table_info = self._get_create_table_stmt(table_name)
            if self._sample_rows_in_table_info:
                table_info += "\n\n/*"
                table_info += f"\n{self._get_sample_spark_rows(table_name)}\n"
                table_info += "*/"
            tables.append(table_info)
        final_str = "\n\n".join(tables)
        return final_str
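
    # The returned table info looks roughly like the sketch below ("titanic" is
    # a hypothetical table name; the sample values are illustrative only):
    #
    #   CREATE TABLE spark_catalog.default.titanic (
    #     PassengerId INT,
    #     Name STRING)
    #   ;
    #
    #   /*
    #   3 rows from titanic table:
    #   PassengerId  Name
    #   1            Braund, Mr. Owen Harris
    #   ...
    #   */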

    def _get_sample_spark_rows(self, table: str) -> str:
        query = f"SELECT * FROM {table} LIMIT {self._sample_rows_in_table_info}"
        df = self._spark.sql(query)
        columns_str = "\t".join(f.name for f in df.schema.fields)
        try:
            sample_rows = self._get_dataframe_results(df)
            # save the sample rows in string format
            sample_rows_str = "\n".join(["\t".join(row) for row in sample_rows])
        except Exception:
            # Sampling may fail for some tables; fall back to omitting sample rows.
            sample_rows_str = ""

        return (
            f"{self._sample_rows_in_table_info} rows from {table} table:\n"
            f"{columns_str}\n"
            f"{sample_rows_str}"
        )

    def _convert_row_as_tuple(self, row: Row) -> tuple:
        return tuple(str(value) for value in row.asDict().values())

    def _get_dataframe_results(self, df: DataFrame) -> list:
        return list(map(self._convert_row_as_tuple, df.collect()))

    def run(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command; if fetch is "one", return only the first row."""
        df = self._spark.sql(command)
        if fetch == "one":
            df = df.limit(1)
        return str(self._get_dataframe_results(df))

    def get_table_info_no_throw(self, table_names: Optional[List[str]] = None) -> str:
        """Get information about specified tables.

        Follows best practices as specified in: Rajkumar et al, 2022
        (https://arxiv.org/abs/2204.00498)

        If `sample_rows_in_table_info` is positive, that number of sample rows
        will be appended to each table description. This can increase performance
        as demonstrated in the paper.
        """
        try:
            return self.get_table_info(table_names)
        except ValueError as e:
            # Format the error message.
            return f"Error: {e}"

    def run_no_throw(self, command: str, fetch: str = "all") -> str:
        """Execute a SQL command and return a string representing the results.

        If the statement returns rows, a string of the results is returned.
        If the statement returns no rows, an empty string is returned.

        If the statement throws an error, the error message is returned.
        """
        try:
            return self.run(command, fetch)
        except Exception as e:
            # Format the error message.
            return f"Error: {e}"