File size: 3,634 Bytes
d25ee4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import logging

import ibis
import sqlglot
from sqlglot import optimizer
from sqlglot.optimizer import qualify
from sqlglot.errors import OptimizeError, ParseError

class Database:
    """Wrapper around an ibis connection that routes SQL through sqlglot.

    Queries are parsed with sqlglot in ``engine_dialect``, optimized
    (best-effort) against the live table schema, and executed through ibis.
    The class can also render a plain-text description of every table,
    optionally enriched with per-table documentation and example text.
    """

    # Module-scoped logger; replaces ad-hoc print() debugging/warnings.
    _logger = logging.getLogger(__name__)

    def __init__(self, connection_url, engine_dialect="mysql") -> None:
        """Store connection settings; no connection is opened here.

        Args:
            connection_url: URL handed verbatim to ``ibis.connect``.
            engine_dialect: sqlglot dialect name used to parse, optimize and
                re-render queries (default ``"mysql"``).
        """
        self._connect_url = connection_url
        self.engine_dialect = engine_dialect
        # table name -> free-text documentation included in table descriptions
        self._tables_docs = {}
        # table name -> example text included in table descriptions
        self._table_exemple = {}
        # Populated by connect(); None until then so methods that need a
        # connection can fail with a clear message instead of AttributeError.
        self._con = None

    def connect(self):
        """Open the ibis connection and return a human-readable status line.

        Returns:
            A success message naming the connection URL.

        Raises:
            Whatever ``ibis.connect`` raises; it propagates unchanged and
            leaves any previously established connection in place.
        """
        self._con = ibis.connect(self._connect_url)
        return f"✅ Connection to {self._connect_url} OK!"

    def _require_connection(self):
        """Return the active ibis backend, or raise if connect() was not called.

        Raises:
            RuntimeError: if ``connect()`` has not succeeded yet.
        """
        if self._con is None:
            raise RuntimeError(
                "Database is not connected; call connect() before using it."
            )
        return self._con

    def _optimize_query(self, sql, schema):
        """Run the sqlglot optimizer on *sql* and render it back to SQL text.

        Args:
            sql: the parsed sqlglot expression to optimize (``query`` passes
                the result of ``sqlglot.parse_one``).
            schema: mapping ``{table: {column: type string}}`` as produced by
                ``_build_schema``.

        Returns:
            The optimized query as SQL text in ``engine_dialect``.
        """
        optimized_expression = optimizer.optimize(
            sql, schema=schema, dialect=self.engine_dialect
        )
        return optimized_expression.sql(dialect=self.engine_dialect)

    def _pretify_table(self, table, columns):
        """Render one table as text: optional docs, optional example, then columns.

        Args:
            table: table name.
            columns: mapping of column name -> type string.

        Returns:
            The description; every section ends with a newline.
        """
        parts = []
        if table in self._tables_docs:
            parts.append(f"## Documentation \n{self._tables_docs[table]}\n")
        if table in self._table_exemple:
            # Terminate the example so the "Table (...)" header below starts
            # on its own line instead of being glued to the example text.
            parts.append(f"## Exemple \n{self._table_exemple[table]}\n")
        parts.append(f"Table ({table}) with  {len(columns)} fields : \n")
        parts.extend(
            f"\t{field} of type : {dtype}\n" for field, dtype in columns.items()
        )
        return "".join(parts)

    def add_table_documentation(self, table_name, documentation):
        """Attach free-text *documentation* to *table_name*, replacing any previous."""
        self._tables_docs[table_name] = documentation

    def add_table_exemple(self, table_name, exemples):
        """Attach example text *exemples* to *table_name*, replacing any previous."""
        self._table_exemple[table_name] = exemples

    def get_tables_array(self):
        """Return one rendered description string per table in the live schema."""
        schema = self._build_schema()
        return [
            self._pretify_table(table, columns) for table, columns in schema.items()
        ]

    def _pretify_schema(self):
        """Return every table description concatenated, each followed by a newline."""
        schema = self._build_schema()
        return "".join(
            self._pretify_table(table, columns) + "\n"
            for table, columns in schema.items()
        )

    def _build_schema(self):
        """Introspect the connected database into ``{table: {column: dtype str}}``.

        Tables whose schema cannot be read are skipped (with a logged warning)
        so that one unreadable table does not hide all the others.

        Raises:
            RuntimeError: if ``connect()`` has not been called.
        """
        con = self._require_connection()
        schema = {}
        for table_name in con.list_tables():
            try:
                table_schema = con.table(table_name).schema()
                schema[table_name] = {
                    col: str(dtype) for col, dtype in table_schema.items()
                }
            except Exception as e:  # best-effort: skip this table, keep the rest
                self._logger.warning(
                    "Could not retrieve schema for table '%s': %s", table_name, e
                )
        return schema

    def query(self, sql_query):
        """Parse, best-effort optimize, and execute *sql_query*.

        Args:
            sql_query: SQL text written in ``engine_dialect``.

        Returns:
            The result of ``execute()`` on the ibis expression built from the
            query (presumably a pandas DataFrame for table results — confirm
            against the ibis backend in use).

        Raises:
            RuntimeError: if ``connect()`` has not been called.
            sqlglot errors (e.g. ``ParseError``) if *sql_query* cannot be parsed.
            Any error ibis raises while compiling or executing the query.
        """
        con = self._require_connection()
        schema = self._build_schema()
        self._logger.debug("Executing query: %s", sql_query)

        expression = sqlglot.parse_one(sql_query, read=self.engine_dialect)

        try:
            final_query = self._optimize_query(expression, schema)
        except Exception as e:
            # Optimization is only an improvement: if sqlglot cannot handle the
            # schema/query, run the query as parsed instead of failing.
            self._logger.debug(
                "Query optimization failed (%s); using unoptimized SQL", e
            )
            final_query = expression.sql(dialect=self.engine_dialect)

        return con.sql(final_query, dialect=self.engine_dialect).execute()

# db = Database("mysql://user:password@localhost:3306/Pokemon")
# db.connect()
# schema = db._build_schema()
# db.add_table_documentation("Defense","This is a super table")
# db.add_table_exemple("Defense","caca")
# db.add_table_exemple("Joueur","ezofkzrfp")
# for table in schema.keys():
#     print(db._pretify_table(table,schema[table]))