from schema_filter import filter_func, SchemaItemClassifierInference
# Example input record. In eval mode the "sql" field does not need to be
# provided, so it is left empty here.
# Each schema item describes one table: its name, an optional table comment,
# its column names, and one comment string per column (same order and length
# as "column_names"; all empty in this example).
data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "",
    "schema": {
        "schema_items": [
            {
                "table_name": "lists",
                "table_comment": "",
                "column_names": [
                    "user_id",
                    "list_id",
                    "list_title",
                    "list_movie_number",
                    "list_update_timestamp_utc",
                    "list_creation_timestamp_utc",
                    "list_followers",
                    "list_url",
                    "list_comments",
                    "list_description",
                    "list_cover_image_url",
                    "list_first_image_url",
                    "list_second_image_url",
                    "list_third_image_url",
                ],
                "column_comments": [
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                ],
            },
            {
                "table_name": "movies",
                "table_comment": "",
                "column_names": [
                    "movie_id",
                    "movie_title",
                    "movie_release_year",
                    "movie_url",
                    "movie_title_language",
                    "movie_popularity",
                    "movie_image_url",
                    "director_id",
                    "director_name",
                    "director_url",
                ],
                "column_comments": [
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                ],
            },
            {
                "table_name": "ratings_users",
                "table_comment": "",
                "column_names": [
                    "user_id",
                    "rating_date_utc",
                    "user_trialist",
                    "user_subscriber",
                    "user_avatar_image_url",
                    "user_cover_image_url",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
                "column_comments": [
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                ],
            },
            {
                "table_name": "lists_users",
                "table_comment": "",
                "column_names": [
                    "user_id",
                    "list_id",
                    "list_update_date_utc",
                    "list_creation_date_utc",
                    "user_trialist",
                    "user_subscriber",
                    "user_avatar_image_url",
                    "user_cover_image_url",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
                "column_comments": [
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                ],
            },
            {
                "table_name": "ratings",
                "table_comment": "",
                "column_names": [
                    "movie_id",
                    "rating_id",
                    "rating_url",
                    "rating_score",
                    "rating_timestamp_utc",
                    "critic",
                    "critic_likes",
                    "critic_comments",
                    "user_id",
                    "user_trialist",
                    "user_subscriber",
                    "user_eligible_for_trial",
                    "user_has_payment_method",
                ],
                "column_comments": [
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                    "",
                ],
            },
        ],
    },
}
# filter_func is called with a list of records, so wrap the single example.
dataset = [data]

# Keep at most 7 tables from the database.
num_top_k_tables = 7
# For each retained table keep at most 10 columns, so the input prompt
# contains at most 7 * 10 = 70 columns.
num_top_k_columns = 10

# Load the schema item classifier model.
# NOTE(review): "sic_merged" is presumably the path/name of the trained
# classifier checkpoint -- confirm against SchemaItemClassifierInference.
sic = SchemaItemClassifierInference("sic_merged")

# For test (eval) data, the trained classifier is used to score tables and
# columns against the user's question.
dataset = filter_func(
    dataset=dataset,
    dataset_type="eval",
    sic=sic,
    num_top_k_tables=num_top_k_tables,
    num_top_k_columns=num_top_k_columns,
)
print(dataset)