antonlabate commited on
Commit
3124aa4
1 Parent(s): 082b881
generate_text2sql_dataset_amr.sh ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -e
2
+
3
+ # generate text2sql training dataset with noise_rate 0.2
4
+ python text2sql_data_generator.py \
5
+ --input_dataset_path "./data/preprocessed_data/preprocessed_train_spider_amr.json" \
6
+ --output_dataset_path "./data/preprocessed_data/resdsql_train_spider_amr.json" \
7
+ --topk_table_num 4 \
8
+ --topk_column_num 5 \
9
+ --mode "train" \
10
+ --noise_rate 0.2 \
11
+ --use_contents \
12
+ --add_fk_info \
13
+ --output_skeleton \
14
+ --target_type "sql"
15
+
16
+ # predict probability for each schema item in the eval set
17
+ python schema_item_classifier.py \
18
+ --batch_size 32 \
19
+ --device "0" \
20
+ --seed 42 \
21
+ --save_path "./models/text2sql_schema_item_classifier_semantic" \
22
+ --dev_filepath "./data/preprocessed_data/preprocessed_dev_amr.json" \
23
+ --output_filepath "./data/preprocessed_data/dev_with_probs_amr.json" \
24
+ --use_contents \
25
+ --add_fk_info \
26
+ --mode "eval"
27
+
28
+ # generate text2sql development dataset
29
+ python text2sql_data_generator.py \
30
+ --input_dataset_path "./data/preprocessed_data/dev_with_probs_amr.json" \
31
+ --output_dataset_path "./data/preprocessed_data/resdsql_dev_amr.json" \
32
+ --topk_table_num 4 \
33
+ --topk_column_num 5 \
34
+ --mode "eval" \
35
+ --use_contents \
36
+ --add_fk_info \
37
+ --output_skeleton \
38
+ --target_type "sql"
preprocess.sh ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -e
2
+
3
+ # preprocess train_spider dataset
4
+ python preprocessing.py \
5
+ --mode "train" \
6
+ --table_path "./data/spider_amr/tables.json" \
7
+ --input_dataset_path "./data/spider_amr/train_spider.json" \
8
+ --output_dataset_path "./data/preprocessed_data/preprocessed_train_spider_amr.json" \
9
+ --db_path "./database" \
10
+ --target_type "sql"
11
+
12
+ # preprocess dev dataset
13
+ python preprocessing.py \
14
+ --mode "eval" \
15
+ --table_path "./data/spider_amr/tables.json" \
16
+ --input_dataset_path "./data/spider_amr/dev.json" \
17
+ --output_dataset_path "./data/preprocessed_data/preprocessed_dev_amr.json" \
18
+ --db_path "./database"\
19
+ --target_type "sql"
train_text2sql_schema_item_classifier.sh ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -e
2
+
3
+ # train schema item classifier
4
+ python -u schema_item_classifier.py \
5
+ --batch_size 8 \
6
+ --gradient_descent_step 2 \
7
+ --device "0" \
8
+ --learning_rate 1e-5 \
9
+ --gamma 2.0 \
10
+ --alpha 0.75 \
11
+ --epochs 32 \
12
+ --patience 16 \
13
+ --seed 42 \
14
+ --save_path "./models/text2sql_schema_item_classifier_semantic" \
15
+ --tensorboard_save_path "./tensorboard_log/text2sql_schema_item_classifier_semantic" \
16
+ --train_filepath "./data/preprocessed_data/preprocessed_train_spider_amr.json" \
17
+ --dev_filepath "./data/preprocessed_data/preprocessed_dev_amr.json" \
18
+ --model_name_or_path "roberta-large" \
19
+ --use_contents \
20
+ --add_fk_info \
21
+ --mode "train"
train_text2sql_t5_base.sh ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ set -e
2
+
3
+ # train text2sql-t5-base model
4
+ python -u text2sql_inputgrande.py \
5
+ --batch_size 8 \
6
+ --gradient_descent_step 2 \
7
+ --device "0" \
8
+ --learning_rate 1e-4 \
9
+ --epochs 128 \
10
+ --seed 42 \
11
+ --save_path "./models/text2sql-t5-amr" \
12
+ --tensorboard_save_path "./tensorboard_log/text2sql-t5-amr" \
13
+ --model_name_or_path "t5-base" \
14
+ --use_adafactor \
15
+ --mode train \
16
+ --train_filepath "./data/preprocessed_data/resdsql_train_spider_amr.json"
17
+
18
+ # select the best text2sql-t5-base ckpt
19
+ python -u evaluate_text2sql_ckpts_inputgrande.py \
20
+ --batch_size 8 \
21
+ --device "0" \
22
+ --seed 42 \
23
+ --save_path "./models/text2sql-t5-amr" \
24
+ --eval_results_path "./eval_results/text2sql-t5-amr" \
25
+ --mode eval \
26
+ --dev_filepath "./data/preprocessed_data/resdsql_dev_amr.json" \
27
+ --original_dev_filepath "./data/spider_amr/dev.json" \
28
+ --db_path "./database" \
29
+ --num_beams 8 \
30
+ --num_return_sequences 8 \
31
+ --target_type "sql"