# schema_filter/training_mode.py
# Uploaded by justinsiow in commit 1e712af ("Uploaded Utils, Pycache and Python Files"); 4.63 kB.
import re

from schema_filter import filter_func
# Schema of the example database: (table name, column names) per table.
# Every table and column comment in this example is empty, so those fields
# are generated below rather than spelled out.
_EXAMPLE_TABLES = [
    ("lists", [
        "user_id",
        "list_id",
        "list_title",
        "list_movie_number",
        "list_update_timestamp_utc",
        "list_creation_timestamp_utc",
        "list_followers",
        "list_url",
        "list_comments",
        "list_description",
        "list_cover_image_url",
        "list_first_image_url",
        "list_second_image_url",
        "list_third_image_url",
    ]),
    ("movies", [
        "movie_id",
        "movie_title",
        "movie_release_year",
        "movie_url",
        "movie_title_language",
        "movie_popularity",
        "movie_image_url",
        "director_id",
        "director_name",
        "director_url",
    ]),
    ("ratings_users", [
        "user_id",
        "rating_date_utc",
        "user_trialist",
        "user_subscriber",
        "user_avatar_image_url",
        "user_cover_image_url",
        "user_eligible_for_trial",
        "user_has_payment_method",
    ]),
    ("lists_users", [
        "user_id",
        "list_id",
        "list_update_date_utc",
        "list_creation_date_utc",
        "user_trialist",
        "user_subscriber",
        "user_avatar_image_url",
        "user_cover_image_url",
        "user_eligible_for_trial",
        "user_has_payment_method",
    ]),
    ("ratings", [
        "movie_id",
        "rating_id",
        "rating_url",
        "rating_score",
        "rating_timestamp_utc",
        "critic",
        "critic_likes",
        "critic_comments",
        "user_id",
        "user_trialist",
        "user_subscriber",
        "user_eligible_for_trial",
        "user_has_payment_method",
    ]),
]

# One training sample: the natural-language question ("text"), its gold SQL
# ("sql"), and the database schema as a list of table descriptions.
data = {
    "text": "Name movie titles released in year 1945. Sort the listing by the descending order of movie popularity.",
    "sql": "SELECT movie_title FROM movies WHERE movie_release_year = 1945 ORDER BY movie_popularity DESC LIMIT 1",
    "schema": {
        "schema_items": [
            {
                "table_name": table_name,
                "table_comment": "",
                # Copy so each sample owns its own column list.
                "column_names": list(column_names),
                # One (empty) comment per column, kept parallel to column_names.
                "column_comments": [""] * len(column_names),
            }
            for table_name, column_names in _EXAMPLE_TABLES
        ]
    },
}
def find_used_tables_and_columns(dataset):
for data in dataset:
sql = data["sql"].lower()
data["table_labels"] = []
data["column_labels"] = []
for table_info in data["schema"]["schema_items"]:
table_name = table_info["table_name"]
data["table_labels"].append(1 if table_name.lower() in sql else 0)
data["column_labels"].append([1 if column_name.lower() in sql else 0 \
for column_name in table_info["column_names"]])
return dataset
# Keep at most this many tables from the database schema.
num_top_k_tables = 6
# For each retained table keep at most this many columns, so the prompt
# contains at most num_top_k_tables * num_top_k_columns = 6 * 6 = 36 columns.
num_top_k_columns = 6

# Use the gold SQL to label which tables and columns are used.
dataset = find_used_tables_and_columns([data])

# For training data the filtering step can be simulated from the SQL-derived
# labels, so the schema item classifier (sic) is not needed: pass None and no
# model is used.
dataset = filter_func(
    dataset=dataset,
    dataset_type="train",
    sic=None,
    num_top_k_tables=num_top_k_tables,
    num_top_k_columns=num_top_k_columns,
)
print(dataset)