"""Label which schema items a gold SQL query uses, then run the schema filter.

The example below builds a one-sample training dataset, derives table/column
labels from its gold SQL, and passes it to ``schema_filter.filter_func``.
"""
import re

data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1",
    "schema": {
        "schema_items": [
            {
                "table_name": "lists",
                "table_comment": "",
                "column_names": [
                    "user_id",
                    "list_id",
                    "list_title",
                    "list_movie_number",
                    "list_update_timestamp_utc",
                    "list_creation_timestamp_utc",
                    "list_followers",
                    "list_url",
                    "list_comments",
                    "list_description",
                    "list_cover_image_url",
                    "list_first_image_url",
                    "list_second_image_url",
                    "list_third_image_url",
                ],
                "column_comments": [""] * 14,
            },
            {
                "table_name": "movies",
                "table_comment": "",
                "column_names": [
                    "movie_id",
                    "movie_title",
                    "movie_release_year",
                    "movie_url",
                    "movie_title_language",
                    "movie_popularity",
                    "movie_image_url",
                    "director_id",
                    "director_name",
                    "director_url",
                ],
                "column_comments": [""] * 10,
            },
            {
                "table_name": "ratings_users",
                "table_comment": "",
                "column_names": [
                    "user_id",
                    "rating_date_utc",
                    "user_trialist",
                    "user_subscriber",
                    "user_avatar_image_url",
                    "user_cover_image_url",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
                "column_comments": [""] * 8,
            },
            {
                "table_name": "lists_users",
                "table_comment": "",
                "column_names": [
                    "user_id",
                    "list_id",
                    "list_update_date_utc",
                    "list_creation_date_utc",
                    "user_trialist",
                    "user_subscriber",
                    "user_avatar_image_url",
                    "user_cover_image_url",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
                "column_comments": [""] * 10,
            },
            {
                "table_name": "ratings",
                "table_comment": "",
                "column_names": [
                    "movie_id",
                    "rating_id",
                    "rating_url",
                    "rating_score",
                    "rating_timestamp_utc",
                    "critic",
                    "critic_likes",
                    "critic_comments",
                    "user_id",
                    "user_trialist",
                    "user_subscriber",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
                "column_comments": [""] * 13,
            },
        ]
    },
}


def _mentions(sql, name):
    """Return True if *name* occurs in *sql* as a whole identifier.

    Both arguments are expected to be lower-cased already. The ``\\w``
    look-arounds stop a name from matching inside a longer identifier
    (``lists`` inside ``lists_users``, ``critic`` inside ``critic_likes``)
    while still matching qualified (``T1.col``), quoted or backticked uses.
    """
    pattern = r"(?<!\w)" + re.escape(name) + r"(?!\w)"
    return re.search(pattern, sql) is not None


def find_used_tables_and_columns(dataset):
    """Annotate each sample with 0/1 labels for the schema items its SQL uses.

    Each sample is a dict with a ``"sql"`` string and a
    ``"schema"]["schema_items"]`` list of tables, each carrying
    ``"table_name"`` and ``"column_names"``. Matching is case-insensitive.

    Every sample is mutated in place:
      * ``"table_labels"``: one 0/1 per table, in ``schema_items`` order.
      * ``"column_labels"``: per table, one 0/1 per entry of ``column_names``.

    A column is only labelled 1 if its table is labelled 1, since SQL cannot
    reference a column of a table that does not appear in the query.

    Returns the same ``dataset`` list, for chaining.
    """
    for sample in dataset:
        sql = sample["sql"].lower()
        sample["table_labels"] = []
        sample["column_labels"] = []
        for table_info in sample["schema"]["schema_items"]:
            table_used = _mentions(sql, table_info["table_name"].lower())
            sample["table_labels"].append(1 if table_used else 0)
            sample["column_labels"].append([
                1 if table_used and _mentions(sql, column_name.lower()) else 0
                for column_name in table_info["column_names"]
            ])
    return dataset


def main():
    """Label the example sample from its gold SQL and run the schema filter."""
    # Imported here so that the labelling helper above can be imported
    # without the project-local schema_filter module being available.
    from schema_filter import filter_func

    # Find the tables and columns used by the SQL.
    dataset = find_used_tables_and_columns([data])
    # Keep at most 6 tables from the database.
    num_top_k_tables = 6
    # For each kept table keep at most 6 of its columns, so the prompt
    # contains at most 6 * 6 = 36 columns.
    num_top_k_columns = 6
    # For training data the filtering can be simulated from the gold SQL, so
    # sic (the schema item classifier) can simply be None; no model is needed.
    dataset = filter_func(
        dataset=dataset,
        dataset_type="train",
        sic=None,
        num_top_k_tables=num_top_k_tables,
        num_top_k_columns=num_top_k_columns,
    )
    print(dataset)


if __name__ == "__main__":
    main()