tdoehmen commited on
Commit
c47beda
·
verified ·
1 Parent(s): 77490e0

Update duckdb-nsql/eval/prompt_formatters.py

Browse files
duckdb-nsql/eval/prompt_formatters.py CHANGED
@@ -105,7 +105,7 @@ class RajkumarFormatter:
105
 
106
  # If SQL blocks found, use the last one
107
  if sql_blocks:
108
- return ensure_semicolon(clean_code_block(sql_blocks[-1]))
109
 
110
  # If no SQL blocks, look for generic code blocks
111
  generic_blocks = []
@@ -127,17 +127,15 @@ class RajkumarFormatter:
127
 
128
  # If generic blocks found, use the last one
129
  if generic_blocks:
130
- return ensure_semicolon(clean_code_block(generic_blocks[-1]))
131
 
132
  # If no code blocks found at all, take everything up to first semicolon
133
  semicolon_pos = output_sql.find(';')
134
  if semicolon_pos != -1:
135
- return ensure_semicolon(output_sql[:semicolon_pos].strip())
136
 
137
  # If no semicolon found, use the entire text
138
- extracted = ensure_semicolon(output_sql.strip())
139
- extracted = extracted.replace('```sql','').replace('```','').strip()
140
- return extracted
141
 
142
  @classmethod
143
  def format_gold_output(cls, output_sql: str) -> str:
 
105
 
106
  # If SQL blocks found, use the last one
107
  if sql_blocks:
108
+ return ensure_semicolon(clean_code_block(sql_blocks[-1])).replace('```sql','').replace('```','').strip()
109
 
110
  # If no SQL blocks, look for generic code blocks
111
  generic_blocks = []
 
127
 
128
  # If generic blocks found, use the last one
129
  if generic_blocks:
130
+ return ensure_semicolon(clean_code_block(generic_blocks[-1])).replace('```sql','').replace('```','').strip()
131
 
132
  # If no code blocks found at all, take everything up to first semicolon
133
  semicolon_pos = output_sql.find(';')
134
  if semicolon_pos != -1:
135
+ return ensure_semicolon(output_sql[:semicolon_pos].strip()).replace('```sql','').replace('```','').strip()
136
 
137
  # If no semicolon found, use the entire text
138
+ return ensure_semicolon(output_sql.strip()).replace('```sql','').replace('```','').strip()
 
 
139
 
140
  @classmethod
141
  def format_gold_output(cls, output_sql: str) -> str: