import re
from multiprocessing import cpu_count

import gradio as gr
import pandas as pd
from keras.src.saving import load_model
from keras.src.utils import set_random_seed
from numpy import int64
from pandarallel import pandarallel
from sklearn.preprocessing import RobustScaler

set_random_seed(65536)
pandarallel.initialize(use_memory_fs=True, nb_workers=cpu_count())
model = load_model('./sqid.keras')


def sql_tokenize(sql_query):
    # Normalize spacing around SQL punctuation and decode common URL-encoded
    # characters so the query splits into meaningful tokens.
    sql_query = sql_query.replace('`', ' ').replace('%20', ' ').replace('=', ' = ').replace('((', ' (( ').replace(
        '))', ' )) ').replace('(', ' ( ').replace(')', ' ) ').replace('||', ' || ').replace(',', '').replace(
        '--', ' -- ').replace(':', ' : ').replace('%23', ' # ').replace('+', ' + ').replace('!=', ' != ') \
        .replace('"', ' " ').replace('%26', ' and ').replace('$', ' $ ').replace('%28', ' ( ').replace('%2A', ' * ') \
        .replace('%7C', ' | ').replace('&', ' & ').replace(']', ' ] ').replace('[', ' [ ').replace(';', ' ; ').replace(
        '/*', ' /* ')
    sql_reserved = {
        'SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'IN', 'LIKE', 'ORDER', 'BY', 'GROUP', 'HAVING', 'LIMIT',
        'BETWEEN', 'IS', 'NULL', '%', 'MIN', 'MAX', 'AS', 'UPPER', 'LOWER', 'TO_DATE', '=', '>', '<', '>=', '<=',
        '!=', '<>', 'EXISTS', 'JOIN', 'UNION', 'ALL', 'ASC', 'DESC', '||', 'AVG', 'EXCEPT', 'INTERSECT', 'CASE',
        'WHEN', 'THEN', 'IF', 'ANY', 'CAST', 'CONVERT', 'COALESCE', 'NULLIF', 'INNER', 'OUTER', 'LEFT', 'RIGHT',
        'FULL', 'CROSS', 'OVER', 'PARTITION', 'SUM', 'COUNT', 'WITH', 'INTERVAL', 'WINDOW', 'ROW_NUMBER', 'RANK',
        'DENSE_RANK', 'NTILE', 'FIRST_VALUE', 'LAST_VALUE', 'LAG', 'LEAD', 'DISTINCT', 'COMMENT', 'INSERT',
        'UPDATE', 'DELETE', 'MERGE', '*', 'generate_series', 'char', 'chr', 'substr', 'lpad', 'extract', 'year',
        'month', 'day', 'timestamp', 'number', 'string', 'concat', 'INFORMATION_SCHEMA', 'SQLITE_MASTER', 'TABLES',
        'COLUMNS', 'CUBE', 'ROLLUP', 'RECURSIVE', 'FILTER', 'EXCLUDE', 'AUTOINCREMENT', 'WITHOUT', 'ROWID',
        'VIRTUAL', 'INDEXED', 'UNINDEXED', 'SERIAL', 'DO', 'RETURNING', 'ILIKE', 'ARRAY', 'ANYARRAY', 'JSONB',
        'TSQUERY', 'SEQUENCE', 'SYNONYM', 'CONNECT', 'START', 'LEVEL', 'ROWNUM', 'NOCOPY', 'MINUS',
        'AUTO_INCREMENT', 'BINARY', 'ENUM', 'REPLACE', 'SET', 'SHOW', 'DESCRIBE', 'USE', 'EXPLAIN', 'STORED',
        'RLIKE', 'MD5', 'SLEEP', 'BENCHMARK', '@@VERSION', 'VERSION', '@VERSION', 'NVARCHAR', '#', '##', 'INJECTX',
        'DELAY', 'WAITFOR', 'RAND',
    }
    tokens = sql_query.split()
    # Drop characters that carry no signal for the classifier.
    tokens = [re.sub(r"""[^*\w\s.=\-><_|()!"']""", '', token) for token in tokens]
    for i, token in enumerate(tokens):
        if token.strip().upper() in sql_reserved:
            continue
        if token.strip().isnumeric():
            tokens[i] = '#NUMBER#'
        elif re.match(r'^[a-zA-Z_.|][a-zA-Z0-9_.|]*$', token.strip()):
            tokens[i] = '#IDENTIFIER#'
        elif re.match(r'^[\d:]*$', token.strip()):
            tokens[i] = '#TIMESTAMP#'
        elif '%' in token.strip():
            tokens[i] = ' '.join(
                [j.strip() if j.strip() in ('%', "'", '"') else '#IDENTIFIER#' for j in token.strip().split('%')])
    return ' '.join(tokens)


def add_features(x):
    # Tokenize the raw queries, then derive simple count-based features that
    # correlate with injection patterns (literals, parentheses, UNIONs, joins,
    # special characters, mismatched quotes, and tautologies such as '1'='1').
    x['Query'] = x['Query'].copy().parallel_apply(lambda a: sql_tokenize(a))
    x['num_tables'] = x['Query'].str.lower().str.count(r'FROM\s+#IDENTIFIER#', flags=re.I)
    x['num_columns'] = x['Query'].str.lower().str.count(r'SELECT\s+#IDENTIFIER#', flags=re.I)
    x['num_literals'] = x['Query'].str.lower().str.count(r"'[^']*'", flags=re.I) \
        + x['Query'].str.lower().str.count(r'"[^"]*"', flags=re.I)
    x['num_parentheses'] = x['Query'].str.lower().str.count(r'\(', flags=re.I) \
        + x['Query'].str.lower().str.count(r'\)', flags=re.I)
    x['has_union'] = x['Query'].str.lower().str.count(' union |union all', flags=re.I) > 0
    x['has_union'] = x['has_union'].astype(int64)
    x['depth_nested_queries'] = x['Query'].str.lower().str.count(r'\(', flags=re.I)
    x['num_join'] = x['Query'].str.lower().str.count(
        ' join |inner join|outer join|full outer join|full inner join|cross join|left join|right join', flags=re.I)
    x['num_sp_chars'] = x['Query'].parallel_apply(lambda a: len(re.findall(r'[\'";\-*/%=><|#]', a)))
    x['has_mismatched_quotes'] = x['Query'].parallel_apply(
        lambda sql_query: 1 if re.search(r"'.*[^']$|\".*[^\"]$", sql_query) else 0)
    x['has_tautology'] = x['Query'].parallel_apply(
        lambda sql_query: 1 if re.search(r"'[\s]*=[\s]*'", sql_query) else 0)
    return x
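
# A quick sketch of what the tokenizer produces, for orientation. The query is
# an invented example; the output below is what the rules above yield for it:
# reserved words and operators pass through, identifiers and numbers become
# placeholder tokens, and quoted literals are left intact.
#
# >>> sql_tokenize("SELECT name FROM users WHERE id = 1 OR '1'='1'")
# "SELECT #IDENTIFIER# FROM #IDENTIFIER# WHERE #IDENTIFIER# = #NUMBER# OR '1' = '1'"
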
def is_malicious_sql(sql, threshold):
    input_df = pd.DataFrame([sql], columns=['Query'])
    input_df = add_features(input_df)
    numeric_features = ['num_tables', 'num_columns', 'num_literals', 'num_parentheses', 'has_union',
                        'depth_nested_queries', 'num_join', 'num_sp_chars', 'has_mismatched_quotes',
                        'has_tautology']
    # Note: fitting RobustScaler on a single row effectively just centers it;
    # ideally the scaler fitted on the training data would be reused here.
    scaler = RobustScaler()
    x_in = scaler.fit_transform(input_df[numeric_features])
    preds = model.predict([input_df['Query'], x_in]).tolist()[0][0]
    if preds > float(threshold):
        return f'Malicious - {preds}'
    return f'Safe - {preds}'


def respond(message, history, threshold):
    # Keep the history bounded and reuse cached verdicts for repeated queries.
    if len(history) > 5:
        history = history[1:]
    for val in history:
        if val[0].lower().strip() == message.lower().strip():
            return val[1]
    val = (message.lower().strip(), is_malicious_sql(message, threshold))
    print(val)
    return val[1]


"""
For information on how to customize the ChatInterface, peruse the gradio docs:
https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    title='SafeSQL-v1-Demo',
    description='Please enter a SQL query as your input. You may adjust the minimum probability threshold for '
                'reporting SQLs as malicious using the slider below.',
    additional_inputs=[
        gr.Slider(minimum=0.01, maximum=0.99, value=0.75, step=0.01, label='Detection Probability Threshold'),
    ],
)

if __name__ == "__main__":
    demo.launch()
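
# A minimal sketch of exercising the classifier without the Gradio UI, e.g.
# from a REPL. The query is an invented example, and the probability shown is
# illustrative only; the actual score depends on the trained model loaded from
# ./sqid.keras.
#
# >>> is_malicious_sql("SELECT * FROM users WHERE name = 'a' OR '1'='1'", 0.75)
# 'Malicious - 0.98...'   # or 'Safe - ...' when the score is below the threshold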