AshenH commited on
Commit
b07564d
·
verified ·
1 Parent(s): 9d6bac9

Update tools/sql_tool.py

Browse files
Files changed (1) hide show
  1. tools/sql_tool.py +28 -20
tools/sql_tool.py CHANGED
@@ -1,7 +1,12 @@
 
1
  import os
2
  import re
 
 
 
3
  import pandas as pd
4
  from typing import Optional
 
5
  from utils.config import AppConfig
6
  from utils.tracing import Tracer
7
 
@@ -21,23 +26,19 @@ class SQLTool:
21
  raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
22
 
23
  # Accept full JSON string from Space Secret
24
- if key_json.strip().startswith("{"):
25
- import json
26
- info = json.loads(key_json)
27
- else:
28
- info = {}
29
-
30
  creds = service_account.Credentials.from_service_account_info(info)
31
  self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
32
 
33
  elif self.backend == "motherduck":
34
- import duckdb, shutil, glob
35
- import os
36
 
 
37
  if not duckdb.__version__.startswith("1.3.2"):
38
  raise RuntimeError(
39
- "MotherDuck currently supports DuckDB 1.3.2. "
40
- "Pin duckdb==1.3.2 in requirements.txt and redeploy."
 
41
  )
42
 
43
  token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
@@ -45,17 +46,18 @@ class SQLTool:
45
  if not token:
46
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
47
 
 
48
  try:
49
  ext_root = os.path.expanduser("~/.duckdb/extensions")
50
  for p in glob.glob(os.path.join(ext_root, "*")):
51
- if "1.3.2" not in p: # keep current version caches, remove others
52
  shutil.rmtree(p, ignore_errors=True)
53
  except Exception:
 
 
54
 
55
- # Plain DuckDB connection
56
- self.client = duckdb.connect()
57
-
58
- # Ensure MotherDuck extension is available and loaded
59
  self.client.execute("INSTALL motherduck;")
60
  self.client.execute("LOAD motherduck;")
61
 
@@ -70,18 +72,21 @@ class SQLTool:
70
  """
71
  Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
72
  Expect users to include table names. Example:
73
- "avg metric by month from analytics.events"
74
  """
75
  m = message.lower()
76
 
77
- # Very basic template example (edit to your tables/columns)
78
  if "avg" in m and " by " in m:
 
 
79
  return (
80
  "-- Example template; edit me\n"
81
  "SELECT DATE_TRUNC('month', date_col) AS month, "
82
  "AVG(metric) AS avg_metric "
83
  "FROM analytics.table "
84
- "GROUP BY 1 ORDER BY 1;"
 
85
  )
86
 
87
  # Pass-through if the user typed SQL explicitly
@@ -93,12 +98,15 @@ class SQLTool:
93
 
94
  def run(self, message: str) -> pd.DataFrame:
95
  sql = self._nl_to_sql(message)
96
- self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
 
 
 
97
 
98
  if self.backend == "bigquery":
99
  df = self.client.query(sql).to_dataframe()
100
  else:
101
- # DuckDB (MotherDuck): fetch_df returns a pandas DataFrame
102
  df = self.client.execute(sql).fetch_df()
103
 
104
  return df
 
1
+ # space/tools/sql_tool.py
2
  import os
3
  import re
4
+ import json
5
+ import shutil
6
+ import glob
7
  import pandas as pd
8
  from typing import Optional
9
+
10
  from utils.config import AppConfig
11
  from utils.tracing import Tracer
12
 
 
26
  raise RuntimeError("Missing GCP_SERVICE_ACCOUNT_JSON secret")
27
 
28
  # Accept full JSON string from Space Secret
29
+ info = json.loads(key_json) if key_json.strip().startswith("{") else {}
 
 
 
 
 
30
  creds = service_account.Credentials.from_service_account_info(info)
31
  self.client = bigquery.Client(credentials=creds, project=cfg.gcp_project)
32
 
33
  elif self.backend == "motherduck":
34
+ import duckdb
 
35
 
36
+ # ---- Enforce supported DuckDB version for MotherDuck extension ----
37
  if not duckdb.__version__.startswith("1.3.2"):
38
  raise RuntimeError(
39
+ f"Incompatible DuckDB version {duckdb.__version__}. "
40
+ "MotherDuck currently supports DuckDB 1.3.2. "
41
+ "Pin duckdb==1.3.2 in requirements.txt and redeploy."
42
  )
43
 
44
  token = self.cfg.motherduck_token or os.getenv("MOTHERDUCK_TOKEN")
 
46
  if not token:
47
  raise RuntimeError("Missing MOTHERDUCK_TOKEN")
48
 
49
+ # ---- Clean stale extension caches compiled for other DuckDB versions ----
50
  try:
51
  ext_root = os.path.expanduser("~/.duckdb/extensions")
52
  for p in glob.glob(os.path.join(ext_root, "*")):
53
+ if "1.3.2" not in p: # keep only current version caches
54
  shutil.rmtree(p, ignore_errors=True)
55
  except Exception:
56
+ # best-effort cleanup; proceed even if it fails
57
+ pass
58
 
59
+ # ---- Connect & load MotherDuck extension ----
60
+ self.client = duckdb.connect() # in-memory connection; we'll ATTACH MotherDuck
 
 
61
  self.client.execute("INSTALL motherduck;")
62
  self.client.execute("LOAD motherduck;")
63
 
 
72
  """
73
  Minimal NL2SQL heuristic; replace with your own mapping or LLM prompt.
74
  Expect users to include table names. Example:
75
+ "avg metric by month from analytics.events"
76
  """
77
  m = message.lower()
78
 
79
+ # Very basic template example (edit table/columns to your schema)
80
  if "avg" in m and " by " in m:
81
+ # DuckDB uses DATE_TRUNC('month', col); BigQuery uses DATE_TRUNC(col, MONTH).
82
+ # This generic SQL should work in DuckDB/MotherDuck; adapt if using BigQuery.
83
  return (
84
  "-- Example template; edit me\n"
85
  "SELECT DATE_TRUNC('month', date_col) AS month, "
86
  "AVG(metric) AS avg_metric "
87
  "FROM analytics.table "
88
+ "GROUP BY 1 "
89
+ "ORDER BY 1;"
90
  )
91
 
92
  # Pass-through if the user typed SQL explicitly
 
98
 
99
  def run(self, message: str) -> pd.DataFrame:
100
  sql = self._nl_to_sql(message)
101
+ try:
102
+ self.tracer.trace_event("sql_query", {"sql": sql, "backend": self.backend})
103
+ except Exception:
104
+ pass
105
 
106
  if self.backend == "bigquery":
107
  df = self.client.query(sql).to_dataframe()
108
  else:
109
+ # DuckDB (MotherDuck)
110
  df = self.client.execute(sql).fetch_df()
111
 
112
  return df