tdoehmen committed on
Commit b9dc6d6
1 Parent(s): d9c57da

no subprocess

Files changed (3)
  1. app.py +10 -146
  2. evaluation_logic.py +193 -0
  3. requirements.txt +2 -0
app.py CHANGED
@@ -1,157 +1,21 @@
  import gradio as gr
- import os
- import sys
- from pathlib import Path
- from datetime import datetime
- import json
+ from evaluation_logic import run_evaluation, AVAILABLE_PROMPT_FORMATS

- # Add the duckdb-nsql directory to the Python path
- current_dir = Path(__file__).resolve().parent
- duckdb_nsql_dir = current_dir / 'duckdb-nsql'
- eval_dir = duckdb_nsql_dir / 'eval'
- sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])
-
- # Import necessary functions and classes from predict.py and evaluate.py
- from eval.predict import predict, console, get_manifest, DefaultLoader
- from eval.constants import PROMPT_FORMATTERS
- from eval.evaluate import evaluate, compute_metrics, get_to_print
- from eval.evaluate import test_suite_evaluation, read_tables_json
-
-
- def run_evaluation(model_name):
-     results = []
-
-     if "OPENROUTER_API_KEY" not in os.environ:
-         return "Error: OPENROUTER_API_KEY not found in environment variables."
-
-     try:
-         # Set up the arguments similar to the CLI in predict.py
-         dataset_path = "duckdb-nsql/eval/data/dev.json"
-         table_meta_path = "duckdb-nsql/eval/data/tables.json"
-         output_dir = "duckdb-nsql/output/"
-         prompt_format = "duckdbinstgraniteshort"
-         stop_tokens = [';']
-         max_tokens = 30000
-         temperature = 0.1
-         num_beams = -1
-         manifest_client = "openrouter"
-         manifest_engine = model_name
-         manifest_connection = "http://localhost:5000"
-         overwrite_manifest = True
-         parallel = False
-
-         # Initialize necessary components
-         data_formatter = DefaultLoader()
-         prompt_formatter = PROMPT_FORMATTERS[prompt_format]()
-
-         # Load manifest
-         manifest = get_manifest(
-             manifest_client=manifest_client,
-             manifest_connection=manifest_connection,
-             manifest_engine=manifest_engine,
-         )
-
-         results.append(f"Using model: {manifest_engine}")
-
-         # Load data and metadata
-         results.append("Loading metadata and data...")
-         db_to_tables = data_formatter.load_table_metadata(table_meta_path)
-         data = data_formatter.load_data(dataset_path)
-
-         # Generate output filename
-         date_today = datetime.now().strftime("%y-%m-%d")
-         pred_filename = f"{prompt_format}_0docs_{manifest_engine.split('/')[-1]}_{Path(dataset_path).stem}_{date_today}.json"
-         pred_path = Path(output_dir) / pred_filename
-         results.append(f"Prediction will be saved to: {pred_path}")
-
-         # Debug: Print predict function signature
-         yield f"Predict function signature: {inspect.signature(predict)}"
-
-         # Run prediction
-         yield "Starting prediction..."
-         try:
-             predict(
-                 dataset_path=dataset_path,
-                 table_meta_path=table_meta_path,
-                 output_dir=output_dir,
-                 prompt_format=prompt_format,
-                 stop_tokens=stop_tokens,
-                 max_tokens=max_tokens,
-                 temperature=temperature,
-                 num_beams=num_beams,
-                 manifest_client=manifest_client,
-                 manifest_engine=manifest_engine,
-                 manifest_connection=manifest_connection,
-                 overwrite_manifest=overwrite_manifest,
-                 parallel=parallel
-             )
-         except TypeError as e:
-             yield f"TypeError in predict function: {str(e)}"
-             yield "Attempting to call predict with only expected arguments..."
-             # Try calling predict with only the arguments it expects
-             predict_args = inspect.getfullargspec(predict).args
-             filtered_args = {k: v for k, v in locals().items() if k in predict_args}
-             predict(**filtered_args)
-
-         results.append("Prediction completed.")
-
-         # Run evaluation
-         results.append("Starting evaluation...")
-
-         # Set up evaluation arguments
-         gold_path = Path(dataset_path)
-         db_dir = "duckdb-nsql/eval/data/databases/"
-         tables_path = Path(table_meta_path)
-
-         kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
-         db_schemas = read_tables_json(str(tables_path))
-
-         gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
-         pred_sqls_dict = [json.loads(l) for l in pred_path.open("r").readlines()]
-
-         gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
-         setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
-         validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
-         gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
-         pred_sqls = [p["pred"] for p in pred_sqls_dict]
-         categories = [p.get("category", "") for p in gold_sqls_dict]
-
-         metrics = compute_metrics(
-             gold_sqls=gold_sqls,
-             pred_sqls=pred_sqls,
-             gold_dbs=gold_dbs,
-             setup_sqls=setup_sqls,
-             validate_sqls=validate_sqls,
-             kmaps=kmaps,
-             db_schemas=db_schemas,
-             database_dir=db_dir,
-             lowercase_schema_match=False,
-             model_name=model_name,
-             categories=categories,
-         )
-
-         results.append("Evaluation completed.")
-
-         # Format and add the evaluation metrics to the results
-         if metrics:
-             to_print = get_to_print({"all": metrics}, "all", model_name, len(gold_sqls))
-             formatted_metrics = "\n".join([f"{k}: {v}" for k, v in to_print.items() if k not in ["slice", "model"]])
-             results.append(f"Evaluation metrics:\n{formatted_metrics}")
-         else:
-             results.append("No evaluation metrics returned.")
-
-     except Exception as e:
-         results.append(f"An unexpected error occurred: {str(e)}")
-
-     return "\n\n".join(results)
+ def gradio_run_evaluation(model_name, prompt_format):
+     yield from run_evaluation(model_name, prompt_format)

  with gr.Blocks() as demo:
      gr.Markdown("# DuckDB SQL Evaluation App")

      model_name = gr.Textbox(label="Model Name (e.g., qwen/qwen-2.5-72b-instruct)")
+     prompt_format = gr.Dropdown(
+         label="Prompt Format",
+         choices=AVAILABLE_PROMPT_FORMATS,
+         value="duckdbinstgraniteshort"
+     )
      start_btn = gr.Button("Start Evaluation")
      output = gr.Textbox(label="Output", lines=20)

-     start_btn.click(fn=run_evaluation, inputs=[model_name], outputs=output)
+     start_btn.click(fn=gradio_run_evaluation, inputs=[model_name, prompt_format], outputs=output)

- demo.launch()
+ demo.queue().launch()
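Note on the new app.py wiring: run_evaluation in evaluation_logic.py is a generator that yields status lines, so the click handler must itself be a generator and the app must enable queuing for Gradio to stream each yielded line into the output Textbox. A minimal, self-contained sketch of that pattern (illustrative names only, not part of the commit):

import time
import gradio as gr

def stream_steps(model):
    # Each yield replaces the Textbox contents, so the log grows line by line.
    log = []
    for step in ("loading data", "predicting", "evaluating"):
        log.append(f"{model}: {step}")
        time.sleep(0.1)
        yield "\n".join(log)

with gr.Blocks() as demo:
    model = gr.Textbox(label="Model")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Log", lines=5)
    btn.click(fn=stream_steps, inputs=[model], outputs=out)

if __name__ == "__main__":
    demo.queue().launch()  # queue() is what enables generator (streaming) handlers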
evaluation_logic.py ADDED
@@ -0,0 +1,193 @@
+ import os
+ import sys
+ from pathlib import Path
+ from datetime import datetime
+ import json
+ import traceback
+
+ # Add the necessary directories to the Python path
+ current_dir = Path(__file__).resolve().parent
+ duckdb_nsql_dir = current_dir / 'duckdb-nsql'
+ eval_dir = duckdb_nsql_dir / 'eval'
+ sys.path.extend([str(current_dir), str(duckdb_nsql_dir), str(eval_dir)])
+
+ # Import necessary functions and classes
+ from eval.predict import get_manifest, DefaultLoader, PROMPT_FORMATTERS, generate_sql
+ from eval.evaluate import evaluate, compute_metrics, get_to_print
+ from eval.evaluate import test_suite_evaluation, read_tables_json
+ from eval.schema import TextToSQLParams, Table
+
+ AVAILABLE_PROMPT_FORMATS = list(PROMPT_FORMATTERS.keys())
+
+ def run_prediction(model_name, prompt_format, output_file):
+     dataset_path = str(eval_dir / "data/dev.json")
+     table_meta_path = str(eval_dir / "data/tables.json")
+     stop_tokens = [';']
+     max_tokens = 30000
+     temperature = 0.1
+     num_beams = -1
+     manifest_client = "openrouter"
+     manifest_engine = model_name
+     manifest_connection = "http://localhost:5000"
+     overwrite_manifest = True
+     parallel = False
+
+     yield "Starting prediction..."
+
+     try:
+         # Initialize necessary components
+         data_formatter = DefaultLoader()
+         prompt_formatter = PROMPT_FORMATTERS[prompt_format]()
+
+         # Load manifest
+         manifest = get_manifest(
+             manifest_client=manifest_client,
+             manifest_connection=manifest_connection,
+             manifest_engine=manifest_engine,
+         )
+
+         # Load data
+         data = data_formatter.load_data(dataset_path)
+         db_to_tables = data_formatter.load_table_metadata(table_meta_path)
+
+         # Prepare input for generate_sql
+         text_to_sql_inputs = []
+         for input_question in data:
+             question = input_question["question"]
+             db_id = input_question.get("db_id", "none")
+             if db_id != "none":
+                 table_params = list(db_to_tables.get(db_id, {}).values())
+             else:
+                 table_params = []
+
+             if len(table_params) == 0:
+                 yield f"[red] WARNING: No tables found for {db_id} [/red]"
+
+             text_to_sql_inputs.append(TextToSQLParams(
+                 instruction=question,
+                 database=db_id,
+                 tables=table_params,
+             ))
+
+         # Generate SQL
+         generated_sqls = generate_sql(
+             manifest=manifest,
+             text_to_sql_in=text_to_sql_inputs,
+             retrieved_docs=[[] for _ in text_to_sql_inputs],  # Assuming no retrieved docs
+             prompt_formatter=prompt_formatter,
+             stop_tokens=stop_tokens,
+             overwrite_manifest=overwrite_manifest,
+             max_tokens=max_tokens,
+             temperature=temperature,
+             num_beams=num_beams,
+             parallel=parallel
+         )
+
+         # Save results
+         with output_file.open('w') as f:
+             for original_data, (sql, _) in zip(data, generated_sqls):
+                 output = {**original_data, "pred": sql}
+                 json.dump(output, f)
+                 f.write('\n')
+
+         yield f"Prediction completed. Results saved to {output_file}"
+     except Exception as e:
+         yield f"Prediction failed with error: {str(e)}"
+         yield f"Error traceback: {traceback.format_exc()}"
+
+ def run_evaluation(model_name, prompt_format="duckdbinstgraniteshort"):
+     if "OPENROUTER_API_KEY" not in os.environ:
+         yield "Error: OPENROUTER_API_KEY not found in environment variables."
+         return
+
+     try:
+         # Set up the arguments
+         dataset_path = str(eval_dir / "data/dev.json")
+         table_meta_path = str(eval_dir / "data/tables.json")
+         output_dir = eval_dir / "output"
+
+         yield f"Using model: {model_name}"
+         yield f"Using prompt format: {prompt_format}"
+
+         output_file = output_dir / f"{prompt_format}_0docs_{model_name.strip().replace('/', '_')}_dev_{datetime.now().strftime('%y-%m-%d')}.json"
+
+         # Ensure the output directory exists
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         if output_file.exists():
+             yield f"Prediction file already exists: {output_file}"
+             yield "Skipping prediction step and proceeding to evaluation."
+         else:
+             # Run prediction
+             for output in run_prediction(model_name, prompt_format, output_file):
+                 yield output
+
+         # Run evaluation
+         yield "Starting evaluation..."
+
+         # Set up evaluation arguments
+         gold_path = Path(dataset_path)
+         db_dir = str(eval_dir / "data/databases/")
+         tables_path = Path(table_meta_path)
+
+         kmaps = test_suite_evaluation.build_foreign_key_map_from_json(str(tables_path))
+         db_schemas = read_tables_json(str(tables_path))
+
+         gold_sqls_dict = json.load(gold_path.open("r", encoding="utf-8"))
+         pred_sqls_dict = [json.loads(l) for l in output_file.open("r").readlines()]
+
+         gold_sqls = [p.get("query", p.get("sql", "")) for p in gold_sqls_dict]
+         setup_sqls = [p["setup_sql"] for p in gold_sqls_dict]
+         validate_sqls = [p["validation_sql"] for p in gold_sqls_dict]
+         gold_dbs = [p.get("db_id", p.get("db", "")) for p in gold_sqls_dict]
+         pred_sqls = [p["pred"] for p in pred_sqls_dict]
+         categories = [p.get("category", "") for p in gold_sqls_dict]
+
+         yield "Computing metrics..."
+         metrics = compute_metrics(
+             gold_sqls=gold_sqls,
+             pred_sqls=pred_sqls,
+             gold_dbs=gold_dbs,
+             setup_sqls=setup_sqls,
+             validate_sqls=validate_sqls,
+             kmaps=kmaps,
+             db_schemas=db_schemas,
+             database_dir=db_dir,
+             lowercase_schema_match=False,
+             model_name=model_name,
+             categories=categories,
+         )
+
+         yield "Evaluation completed."
+
+         if metrics:
+             yield "Overall Results:"
+             overall_metrics = metrics['exec']['all']
+             yield f"Count: {overall_metrics['count']}"
+             yield f"Execution Accuracy: {overall_metrics['exec']:.3f}"
+             yield f"Exact Match Accuracy: {overall_metrics['exact']:.3f}"
+             yield f"Equality: {metrics['equality']['equality']:.3f}"
+             yield f"Edit Distance: {metrics['edit_distance']['edit_distance']:.3f}"
+
+             yield "\nResults by Category:"
+             categories = ['easy', 'medium', 'hard', 'duckdb', 'ddl', 'all']
+
+             for category in categories:
+                 if category in metrics['exec']:
+                     yield f"\n{category}:"
+                     category_metrics = metrics['exec'][category]
+                     yield f"Count: {category_metrics['count']}"
+                     yield f"Execution Accuracy: {category_metrics['exec']:.3f}"
+                 else:
+                     yield f"\n{category}: No data available"
+         else:
+             yield "No evaluation metrics returned."
+     except Exception as e:
+         yield f"An unexpected error occurred: {str(e)}"
+         yield f"Error traceback: {traceback.format_exc()}"
+
+ if __name__ == "__main__":
+     model_name = input("Enter the model name: ")
+     prompt_format = input("Enter the prompt format (default is duckdbinstgraniteshort): ") or "duckdbinstgraniteshort"
+     for result in run_evaluation(model_name, prompt_format):
+         print(result, flush=True)
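For reference, run_prediction writes a JSONL file: one JSON object per line, each being the original dev.json record with an added "pred" field. A short illustrative sketch of reading such a file back (the file name below is hypothetical; the real one is built from the prompt format, model name, and date as in run_evaluation):

import json
from pathlib import Path

# Hypothetical predictions file produced by run_prediction
pred_path = Path("duckdb-nsql/eval/output/duckdbinstgraniteshort_0docs_example-model_dev_24-01-01.json")

with pred_path.open(encoding="utf-8") as f:
    preds = [json.loads(line) for line in f]

# Each record keeps the benchmark fields (question, db_id, ...) plus "pred"
for record in preds[:5]:
    print(record.get("db_id", "none"), "->", record["pred"])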
requirements.txt CHANGED
@@ -20,6 +20,7 @@ peft==0.6.0
  packaging==23.2
  ninja==1.11.1.1
  langchain
+ gradio
  pydantic
  packaging
  #./duckdb-nsql/manifest
@@ -28,3 +29,4 @@ flask
  diffusers
  deepspeed
  sentence_transformers
+ tqdm