"""Demo: prune a database schema down to the items relevant to a question.

Builds a single eval-mode sample (a natural-language question plus the schema
of a movie/ratings database), loads the trained schema-item classifier, and
passes both to ``filter_func`` with the top-k table/column limits below, then
prints the result.
"""

from schema_filter import filter_func, SchemaItemClassifierInference

# Keep at most this many tables from the database.
num_top_k_tables = 7
# For each kept table, keep at most this many columns, so the prompt contains
# at most 7 * 10 = 70 columns.
num_top_k_columns = 10


def _make_table(table_name, column_names):
    """Return one schema item for *table_name* with empty table/column comments.

    Args:
        table_name: Name of the table.
        column_names: Sequence of the table's column names.

    Returns:
        A dict with keys ``table_name``, ``table_comment``, ``column_names``
        and ``column_comments``. ``column_comments`` holds one empty string
        per column, so the two lists always have the same length.
    """
    return {
        "table_name": table_name,
        "table_comment": "",
        "column_names": list(column_names),
        "column_comments": [""] * len(column_names),
    }


# In eval mode the gold SQL does not need to be provided, hence the empty "sql".
data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "",
    "schema": {
        "schema_items": [
            _make_table(
                "lists",
                [
                    "user_id",
                    "list_id",
                    "list_title",
                    "list_movie_number",
                    "list_update_timestamp_utc",
                    "list_creation_timestamp_utc",
                    "list_followers",
                    "list_url",
                    "list_comments",
                    "list_description",
                    "list_cover_image_url",
                    "list_first_image_url",
                    "list_second_image_url",
                    "list_third_image_url",
                ],
            ),
            _make_table(
                "movies",
                [
                    "movie_id",
                    "movie_title",
                    "movie_release_year",
                    "movie_url",
                    "movie_title_language",
                    "movie_popularity",
                    "movie_image_url",
                    "director_id",
                    "director_name",
                    "director_url",
                ],
            ),
            _make_table(
                "ratings_users",
                [
                    "user_id",
                    "rating_date_utc",
                    "user_trialist",
                    "user_subscriber",
                    "user_avatar_image_url",
                    "user_cover_image_url",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
            ),
            _make_table(
                "lists_users",
                [
                    "user_id",
                    "list_id",
                    "list_update_date_utc",
                    "list_creation_date_utc",
                    "user_trialist",
                    "user_subscriber",
                    "user_avatar_image_url",
                    "user_cover_image_url",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
            ),
            _make_table(
                "ratings",
                [
                    "movie_id",
                    "rating_id",
                    "rating_url",
                    "rating_score",
                    "rating_timestamp_utc",
                    "critic",
                    "critic_likes",
                    "critic_comments",
                    "user_id",
                    "user_trialist",
                    "user_subscriber",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
            ),
        ]
    },
}


def main():
    """Load the classifier, filter the sample dataset, and print the result."""
    dataset = [data]

    # Load the classifier model. "sic_merged" is presumably the path/name of the
    # trained checkpoint -- confirm against SchemaItemClassifierInference.
    sic = SchemaItemClassifierInference("sic_merged")

    # For test data, use the trained classifier to score tables and columns
    # against the user question, keeping the top-k of each.
    dataset = filter_func(
        dataset=dataset,
        dataset_type="eval",
        sic=sic,
        num_top_k_tables=num_top_k_tables,
        num_top_k_columns=num_top_k_columns,
    )
    print(dataset)


# Guard so importing this module (e.g. to reuse `data`) does not load the model.
if __name__ == "__main__":
    main()