import os
import json
import torch
import copy
import re
import sqlparse
import sqlite3
from tqdm import tqdm
from utils.db_utils import get_db_schema
from transformers import AutoModelForCausalLM, AutoTokenizer
from whoosh import index
from whoosh.index import create_in
from whoosh.fields import Schema, TEXT
from whoosh.qparser import QueryParser
from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
from schema_item_filter import SchemaItemClassifierInference, filter_schema

def remove_similar_comments(names, comments):
    '''
    Remove table (or column) comments that have a high degree of similarity with their names.
    '''
    new_comments = []
    for name, comment in zip(names, comments):
        if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
            new_comments.append("")
        else:
            new_comments.append(comment)

    return new_comments
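# (Illustration for remove_similar_comments: a column named "stu_id" whose
#  comment is "stu id" reduces to the same string once "_" and " " are
#  stripped, so its comment is treated as redundant and blanked out.)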

def load_db_comments(table_json_path):
    additional_db_info = json.load(open(table_json_path))
    db_comments = dict()
    for db_info in additional_db_info:
        comment_dict = dict()

        column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
        table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
        column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]

        assert len(column_names) == len(column_comments)
        column_comments = remove_similar_comments(column_names, column_comments)

        table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
        table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]

        assert len(table_names) == len(table_comments)
        table_comments = remove_similar_comments(table_names, table_comments)

        for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
            comment_dict[table_name] = {
                "table_comment": table_comment,
                "column_comments": dict()
            }
            for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
                if t_idx == table_idx:
                    comment_dict[table_name]["column_comments"][column_name] = column_comment

        db_comments[db_info["db_id"]] = comment_dict

    return db_comments
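# Illustrative shape of the mapping returned by load_db_comments:
# {
#     "<db_id>": {
#         "<table_name>": {
#             "table_comment": "...",
#             "column_comments": {"<column_name>": "..."}
#         },
#         ...
#     }
# }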

def get_db_id2schema(db_path, tables_json):
    db_comments = load_db_comments(tables_json)
    db_id2schema = dict()

    for db_id in tqdm(os.listdir(db_path)):
        db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)

    return db_id2schema

def get_db_id2ddl(db_path):
    db_ids = os.listdir(db_path)
    db_id2ddl = dict()

    for db_id in db_ids:
        conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
        cursor = conn.cursor()
        cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()
        ddl = []

        for table in tables:
            table_name = table[0]
            table_ddl = table[1]
            table_ddl = table_ddl.replace("\t", " ")
            # collapse repeated spaces into a single space
            while "  " in table_ddl:
                table_ddl = table_ddl.replace("  ", " ")
            table_ddl = re.sub(r'--.*', '', table_ddl)
            table_ddl = sqlparse.format(table_ddl, keyword_case="upper", identifier_case="lower", reindent_aligned=True)
            table_ddl = table_ddl.replace(", ", ",\n ")

            if table_ddl.endswith(";"):
                table_ddl = table_ddl[:-1]
            table_ddl = table_ddl[:-1] + "\n);"
            table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", table_ddl)

            ddl.append(table_ddl)
        db_id2ddl[db_id] = "\n\n".join(ddl)

    return db_id2ddl
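# Approximate shape of one cleaned DDL entry produced above (keywords
# uppercased, identifiers lowercased, one column per line); the table and
# column names here are hypothetical:
# CREATE TABLE singer(
#  id INTEGER,
#  name TEXT
# );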

class ChatBot():
    def __init__(self) -> None:
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        model_name = "seeklhy/codes-1b"

        # Load tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
        # The device the model was loaded on (CUDA if available, otherwise CPU)
        self.device = self.model.device

        # Generation parameters
        self.max_length = 4096
        self.max_new_tokens = 256
        self.max_prefix_length = self.max_length - self.max_new_tokens

        # Load the schema item classifier
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

        # Open a Whoosh index of database contents for each database
        self.db_id2content_searcher = dict()
        for db_id in os.listdir("db_contents_index"):
            index_dir = os.path.join("db_contents_index", db_id)
            if index.exists_in(index_dir):
                ix = index.open_dir(index_dir)
                self.db_id2content_searcher[db_id] = ix

        # Cache schemas and DDL for every database under "databases";
        # the tables.json path below is an assumption and may need adjusting
        self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
        self.db_id2ddl = get_db_id2ddl("databases")

    def get_response(self, question, db_id):
        # Prepare the data for schema filtering
        data = {
            "text": question,
            "schema": copy.deepcopy(self.db_id2schema[db_id]),
            "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
        }

        # Filter the schema based on the classifier's predictions
        data = filter_schema(data, self.sic, 6, 10)
        data["schema_sequence"] = get_db_schema_sequence(data["schema"])
        data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])

        # Prepare the input sequence for the model
        prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
        input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation=False)["input_ids"]
        if len(input_ids) > self.max_prefix_length:
            # Keep only the most recent tokens if the prefix is too long
            input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length - 1):]
        attention_mask = [1] * len(input_ids)

        # Move input tensors to the same device as the model
        inputs = {
            "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(self.device),
            "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(self.device)
        }

        # Generate candidate SQL queries with beam search
        with torch.no_grad():
            generate_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                num_beams=4,
                num_return_sequences=4
            )

        # Decode the generated SQL queries
        generated_sqls = self.tokenizer.batch_decode(generate_ids[:, len(input_ids):], skip_special_tokens=True, clean_up_tokenization_spaces=False)

        # Return the first candidate that executes without errors
        final_generated_sql = None
        for generated_sql in generated_sqls:
            execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
            if execution_error is None:
                final_generated_sql = generated_sql
                break

        # Fall back to the top beam (or an apology) if no candidate executes
        if final_generated_sql is None:
            if generated_sqls[0].strip() != "":
                final_generated_sql = generated_sqls[0].strip()
            else:
                final_generated_sql = "Sorry, I can not generate a suitable SQL query for your question."

        return final_generated_sql.replace("\n", " ")
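
# Illustrative usage sketch: a minimal way to drive ChatBot, assuming the
# "databases" folder and the "db_contents_index" Whoosh indexes used above
# are in place. The db_id and question below are hypothetical.
if __name__ == "__main__":
    bot = ChatBot()
    example_db_id = "concert_singer"  # hypothetical; must match a folder under "databases"
    example_question = "How many singers are there?"
    print(bot.get_response(example_question, example_db_id))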