tdoehmen committed on
Commit
1da0a65
1 Parent(s): 021d6f7

fixed extraction and cleaned output

Browse files
duckdb-nsql/eval/predict.py CHANGED
@@ -294,7 +294,6 @@ def predict(
294
  for prediction, model_response in generated_responses:
295
  prediction = re.sub(r"[\s\t\n]+", " ", prediction)
296
  token_lengths.append(len(tokenizer(prediction).input_ids))
297
- console.print(f"[blue]Prompt:[/blue] {model_response.final_prompt}")
298
  console.print(f"[red]Prediction:[/red] {prediction}")
299
  if data[i].get("query") or data[i].get("sql"):
300
  console.print(
 
294
  for prediction, model_response in generated_responses:
295
  prediction = re.sub(r"[\s\t\n]+", " ", prediction)
296
  token_lengths.append(len(tokenizer(prediction).input_ids))
 
297
  console.print(f"[red]Prediction:[/red] {prediction}")
298
  if data[i].get("query") or data[i].get("sql"):
299
  console.print(
duckdb-nsql/eval/prompt_formatters.py CHANGED
@@ -65,22 +65,77 @@ class RajkumarFormatter:
65
 
66
  @classmethod
67
  def format_model_output(cls, output_sql: str, prompt: str) -> str:
68
- """Format model output."""
69
- time.sleep(10)
70
- clean_sql = (output_sql
71
- .replace('sql\n', '')
72
- .replace('```sql\n', '')
73
- .replace('```duckdb\n', '')
74
- .replace('```\n', '')
75
- .replace('```', '')).strip()
76
-
77
- if clean_sql.find(';') != -1:
78
- clean_sql[:clean_sql.find(';')].strip()
79
-
80
- if not clean_sql.endswith(";"):
81
- clean_sql += ";"
82
-
83
- return clean_sql
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  @classmethod
86
  def format_gold_output(cls, output_sql: str) -> str:
@@ -503,7 +558,7 @@ DuckDB Other Keywords:
503
  `DESC`: The keyword `desc` is used in SQL to describe the DESCENDING order in which query results should be sorted, often associated with the `ORDER BY` clause., Examples: ['SELECT name FROM employees ORDER BY salary DESC;', 'CREATE TABLE example (id INTEGER, name VARCHAR);', 'DESCRIBE example;', 'DESCRIBE SELECT * FROM example WHERE id < 10 ORDER BY name DESC;']
504
  `IS`: The "IS" keyword is used in SQL to perform tests on values to check for NULL values or to use as part of statements like IS DISTINCT FROM which can handle NULLs in equality comparisons., Examples: ['SELECT 4 IS DISTINCT FROM NULL;', 'SELECT 4 IS NOT DISTINCT FROM 4;', 'SELECT NULL IS NULL;', 'SELECT NULL IS NOT NULL;']
505
  `IN`: The `IN` keyword is used in SQL to specify a list of discrete values for a column to match against, typically in a `WHERE` clause, allowing for multiple specific conditions to be evaluated at once., Examples: ["SELECT * FROM employees WHERE department IN ('HR', 'Engineering', 'Marketing');", 'SELECT id, name FROM students WHERE grade IN (10, 11, 12);', "DELETE FROM orders WHERE order_status IN ('Cancelled', 'Returned');", "UPDATE items SET status = 'Unavailable' WHERE item_id IN (1001, 1002, 1003);", "SELECT * FROM logs WHERE severity IN ('ERROR', 'CRITICAL') ORDER BY timestamp DESC;"]
506
- `ALL`: The `ALL` keyword in SQL specifies that operations should retain all duplicate rows, as seen in commands like `UNION ALL`, `INTERSECT ALL`, and `EXCEPT ALL`, which follow bag semantics instead of eliminating duplicates., Examples: ['UNION ALL\n\n```sql\nSELECT * FROM range(2) t1(x)\nUNION ALL\nSELECT * FROM range(3) t2(x);\n```\nThis example demonstrates using `UNION ALL` to combine rows from two queries without eliminating duplicates.', 'INTERSECT ALL\n\n```sql\nSELECT unnest([5, 5, 6, 6, 6, 6, 7, 8]) AS x\nINTERSECT ALL\nSELECT unnest([5, 6, 6, 7, 7, 9]);\n```\nThis example shows using `INTERSECT ALL` to select rows that are present in both result sets, keeping duplicate values.', 'EXCEPT ALL\n\n```sql\nSELECT unnest([5, 5, 6, 6, 6, 6, 7, 8]) AS x\nEXCEPT ALL\nSELECT unnest([5, 6, 6, 7, 7, 9]);\n```\nThis example illustrates `EXCEPT ALL`, which selects all rows present in the first query but not in the second, without removing duplicates.', 'ORDER BY ALL\n\n```sql\nSELECT *\nFROM addresses\nORDER BY ALL;\n```\nThis SQL command uses `ORDER BY ALL` to sort the result set by all columns sequentially from left to right.']
507
  `LIKE`: The `LIKE` expression is used to determine if a string matches a specified pattern, allowing wildcard characters such as `_` to represent any single character and `%` to match any sequence of characters., Examples: ["SELECT 'abc' LIKE 'abc'; -- true", "SELECT 'abc' LIKE 'a%'; -- true", "SELECT 'abc' LIKE '_b_'; -- true", "SELECT 'abc' LIKE 'c'; -- false", "SELECT 'abc' LIKE 'c%'; -- false", "SELECT 'abc' LIKE '%c'; -- true", "SELECT 'abc' NOT LIKE '%c'; -- false", "SELECT 'abc' ILIKE '%C'; -- true"]
508
  `IF`: The `IF` keyword is used in conditional constructs, most commonly found in the `IF NOT EXISTS` or `IF EXISTS` clauses in SQL commands to prevent errors or skip operations, such as creating or dropping a database or table when certain conditions are met., Examples: ['CREATE DATABASE IF NOT EXISTS ducks_db;', 'CREATE TABLE IF NOT EXISTS t1 (i INTEGER, j INTEGER);', 'CREATE TABLE t1 (id INTEGER, PRIMARY KEY (id), UNIQUE (id)) IF NOT EXISTS;']
509
  `EXISTS`: The `EXISTS` operator is used to determine if a subquery returns any rows, returning `true` if at least one row exists and `false` otherwise., Examples: ["SELECT EXISTS (FROM grades WHERE course = 'Math') AS math_grades_present;", "SELECT EXISTS (FROM grades WHERE course = 'History') AS history_grades_present;"]
@@ -528,10 +583,9 @@ Question:
528
  Here is the question or an instruction the user provided:
529
  {question}
530
 
531
- Write a DuckDB SQL query for the given question!
532
 
533
  Answer:
534
- ```
535
  """
536
 
537
  @classmethod
 
65
 
66
  @classmethod
67
  def format_model_output(cls, output_sql: str, prompt: str) -> str:
68
+ def clean_code_block(block: str) -> str:
69
+ """Clean a code block by removing markdown syntax and extra whitespace."""
70
+ # Remove markdown indicators and common SQL prefixes
71
+ cleaned = (block
72
+ .replace('```sql\n', '')
73
+ .replace('```duckdb\n', '')
74
+ .replace('```\n', '')
75
+ .replace('```', '')
76
+ .strip())
77
+
78
+ return cleaned
79
+
80
+ def ensure_semicolon(sql: str) -> str:
81
+ """Ensure the SQL query ends with exactly one semicolon."""
82
+ sql = sql.strip()
83
+ # Remove any existing trailing semicolons
84
+ while sql.endswith(';'):
85
+ sql = sql[:-1].strip()
86
+ # Add back exactly one semicolon
87
+ return sql + ";"
88
+
89
+ # First, try to find SQL-specific code blocks
90
+ sql_blocks = []
91
+ start_pos = 0
92
+ while True:
93
+ start = output_sql.find('```sql', start_pos)
94
+ if start == -1:
95
+ start = output_sql.find('```duckdb', start_pos)
96
+ if start == -1:
97
+ break
98
+
99
+ end = output_sql.find('```', start + 4)
100
+ if end == -1:
101
+ break
102
+
103
+ sql_blocks.append(output_sql[start:end+3])
104
+ start_pos = end + 3
105
+
106
+ # If SQL blocks found, use the last one
107
+ if sql_blocks:
108
+ return ensure_semicolon(clean_code_block(sql_blocks[-1]))
109
+
110
+ # If no SQL blocks, look for generic code blocks
111
+ generic_blocks = []
112
+ start_pos = 0
113
+ while True:
114
+ start = output_sql.find('```', start_pos)
115
+ if start == -1:
116
+ break
117
+
118
+ end = output_sql.find('```', start + 3)
119
+ if end == -1:
120
+ break
121
+
122
+ block = output_sql[start:end+3]
123
+ # Skip if this is actually an SQL block (we already handled those)
124
+ if not block.startswith('```sql') and not block.startswith('```duckdb'):
125
+ generic_blocks.append(block)
126
+ start_pos = end + 3
127
+
128
+ # If generic blocks found, use the last one
129
+ if generic_blocks:
130
+ return ensure_semicolon(clean_code_block(generic_blocks[-1]))
131
+
132
+ # If no code blocks found at all, take everything up to first semicolon
133
+ semicolon_pos = output_sql.find(';')
134
+ if semicolon_pos != -1:
135
+ return ensure_semicolon(output_sql[:semicolon_pos].strip())
136
+
137
+ # If no semicolon found, use the entire text
138
+ return ensure_semicolon(output_sql.strip())
139
 
140
  @classmethod
141
  def format_gold_output(cls, output_sql: str) -> str:
 
558
  `DESC`: The keyword `desc` is used in SQL to describe the DESCENDING order in which query results should be sorted, often associated with the `ORDER BY` clause., Examples: ['SELECT name FROM employees ORDER BY salary DESC;', 'CREATE TABLE example (id INTEGER, name VARCHAR);', 'DESCRIBE example;', 'DESCRIBE SELECT * FROM example WHERE id < 10 ORDER BY name DESC;']
559
  `IS`: The "IS" keyword is used in SQL to perform tests on values to check for NULL values or to use as part of statements like IS DISTINCT FROM which can handle NULLs in equality comparisons., Examples: ['SELECT 4 IS DISTINCT FROM NULL;', 'SELECT 4 IS NOT DISTINCT FROM 4;', 'SELECT NULL IS NULL;', 'SELECT NULL IS NOT NULL;']
560
  `IN`: The `IN` keyword is used in SQL to specify a list of discrete values for a column to match against, typically in a `WHERE` clause, allowing for multiple specific conditions to be evaluated at once., Examples: ["SELECT * FROM employees WHERE department IN ('HR', 'Engineering', 'Marketing');", 'SELECT id, name FROM students WHERE grade IN (10, 11, 12);', "DELETE FROM orders WHERE order_status IN ('Cancelled', 'Returned');", "UPDATE items SET status = 'Unavailable' WHERE item_id IN (1001, 1002, 1003);", "SELECT * FROM logs WHERE severity IN ('ERROR', 'CRITICAL') ORDER BY timestamp DESC;"]
561
+ `ALL`: The `ALL` keyword in SQL specifies that operations should retain all duplicate rows, as seen in commands like `UNION ALL`, `INTERSECT ALL`, and `EXCEPT ALL`, which follow bag semantics instead of eliminating duplicates., Examples: ['UNION ALL ```sql\nSELECT * FROM range(2) t1(x)\nUNION ALL\nSELECT * FROM range(3) t2(x);\n```\nThis example demonstrates using `UNION ALL` to combine rows from two queries without eliminating duplicates.', 'INTERSECT ALL ```sql\nSELECT unnest([5, 5, 6, 6, 6, 6, 7, 8]) AS x\nINTERSECT ALL\nSELECT unnest([5, 6, 6, 7, 7, 9]);\n```\nThis example shows using `INTERSECT ALL` to select rows that are present in both result sets, keeping duplicate values.', 'EXCEPT ALL ```sql\nSELECT unnest([5, 5, 6, 6, 6, 6, 7, 8]) AS x\nEXCEPT ALL\nSELECT unnest([5, 6, 6, 7, 7, 9]);\n```\nThis example illustrates `EXCEPT ALL`, which selects all rows present in the first query but not in the second, without removing duplicates.', 'ORDER BY ALL ```sql\nSELECT *\nFROM addresses\nORDER BY ALL;\n```\nThis SQL command uses `ORDER BY ALL` to sort the result set by all columns sequentially from left to right.']
562
  `LIKE`: The `LIKE` expression is used to determine if a string matches a specified pattern, allowing wildcard characters such as `_` to represent any single character and `%` to match any sequence of characters., Examples: ["SELECT 'abc' LIKE 'abc'; -- true", "SELECT 'abc' LIKE 'a%'; -- true", "SELECT 'abc' LIKE '_b_'; -- true", "SELECT 'abc' LIKE 'c'; -- false", "SELECT 'abc' LIKE 'c%'; -- false", "SELECT 'abc' LIKE '%c'; -- true", "SELECT 'abc' NOT LIKE '%c'; -- false", "SELECT 'abc' ILIKE '%C'; -- true"]
563
  `IF`: The `IF` keyword is used in conditional constructs, most commonly found in the `IF NOT EXISTS` or `IF EXISTS` clauses in SQL commands to prevent errors or skip operations, such as creating or dropping a database or table when certain conditions are met., Examples: ['CREATE DATABASE IF NOT EXISTS ducks_db;', 'CREATE TABLE IF NOT EXISTS t1 (i INTEGER, j INTEGER);', 'CREATE TABLE t1 (id INTEGER, PRIMARY KEY (id), UNIQUE (id)) IF NOT EXISTS;']
564
  `EXISTS`: The `EXISTS` operator is used to determine if a subquery returns any rows, returning `true` if at least one row exists and `false` otherwise., Examples: ["SELECT EXISTS (FROM grades WHERE course = 'Math') AS math_grades_present;", "SELECT EXISTS (FROM grades WHERE course = 'History') AS history_grades_present;"]
 
583
  Here is the question or an instruction the user provided:
584
  {question}
585
 
586
+ Write a DuckDB SQL query for the given question! Make sure to only response with the SQL query and wrap it in ```sql\n<ANSWER>``` markdown code tags.
587
 
588
  Answer:
 
589
  """
590
 
591
  @classmethod
duckdb-nsql/eval/text_to_sql.py CHANGED
@@ -219,11 +219,6 @@ def _run_manifest(
219
  ) -> TextToSQLModelResponse:
220
  """Run manifest for prompt format."""
221
  logger.info(f"PARAMS: {manifest_params}")
222
- if isinstance(prompt, list):
223
- for p in prompt:
224
- logger.info(f"PROMPT: {p['role']}: {p['content']}")
225
- else:
226
- logger.info(f"PROMPT: {prompt}")
227
  start_time = time.time()
228
  # Run result
229
  response = cast(
@@ -248,7 +243,6 @@ def _run_manifest(
248
  cast(str, response.get_response()), prompt
249
  )
250
 
251
- logger.info(f"RAW OUTPUT: {response.get_response()}")
252
  for token in stop_sequences:
253
  sql_query = sql_query.split(token)[0]
254
  logger.info(f"OUTPUT: {sql_query}")
 
219
  ) -> TextToSQLModelResponse:
220
  """Run manifest for prompt format."""
221
  logger.info(f"PARAMS: {manifest_params}")
 
 
 
 
 
222
  start_time = time.time()
223
  # Run result
224
  response = cast(
 
243
  cast(str, response.get_response()), prompt
244
  )
245
 
 
246
  for token in stop_sequences:
247
  sql_query = sql_query.split(token)[0]
248
  logger.info(f"OUTPUT: {sql_query}")
evaluation_logic.py CHANGED
@@ -86,8 +86,8 @@ def save_evaluation(inference_api, model_name, prompt_format, custom_prompt, met
86
  def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
87
  dataset_path = str(eval_dir / "data/dev.json")
88
  table_meta_path = str(eval_dir / "data/tables.json")
89
- stop_tokens = [';']
90
- max_tokens = 32000
91
  temperature = 0
92
  num_beams = -1
93
  manifest_client = inference_api
 
86
  def run_prediction(inference_api, model_name, prompt_format, custom_prompt, output_file):
87
  dataset_path = str(eval_dir / "data/dev.json")
88
  table_meta_path = str(eval_dir / "data/tables.json")
89
+ stop_tokens = ['`<|dummy|>`']
90
+ max_tokens = 1000
91
  temperature = 0
92
  num_beams = -1
93
  manifest_client = inference_api
requirements.txt CHANGED
@@ -28,4 +28,4 @@ diffusers
28
  deepspeed
29
  sentence_transformers
30
  tqdm
31
- pydantic
 
28
  deepspeed
29
  sentence_transformers
30
  tqdm
31
+ pydantic