metadata
library_name: transformers
tags:
  - text-to-SQL
  - SQL
  - code-generation
  - NLQ-to-SQL
  - text2SQL
  - Security
  - Vulnerability detection
datasets:
  - salmane11/SQLShield
language:
  - en
base_model:
  - microsoft/codebert-base

SQLQueryShield

Model Description

SQLQueryShield is a model for detecting vulnerable SQL queries. It classifies each SQL query as either vulnerable (e.g., prone to SQL injection or unsafe execution), reported with the label MALICIOUS, or benign (safe to execute), reported with the label SAFE.

The checkpoint in this repository is based on microsoft/codebert-base and fine-tuned on SQLShield, a text-to-SQL vulnerability detection dataset made up of vulnerable and safe natural-language questions (NLQs) paired with their SQL queries.

Finetuning Procedure

The model was fine-tuned using the Hugging Face Transformers library. The following steps were used:

  1. Dataset: SQLShield. Only the SQL queries from the (NLQ, SQL) pairs were used for training.

  2. Preprocessing:

    • Input Format: Raw SQL query strings.

    • Tokenization: Tokenized with the microsoft/codebert-base tokenizer.

    • Max Length: 128 tokens.

    • Padding and truncation were applied.
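The two steps above can be sketched in plain Python, for illustration only. The sketch assumes hypothetical record fields `nlq`, `sql` and `label` (the actual SQLShield column names are not stated here), and it applies truncation and padding to token IDs that a tokenizer would already have produced. The pad ID of 1 is an assumption based on the RoBERTa-style vocabulary CodeBERT uses. In practice the microsoft/codebert-base tokenizer does this itself when called with `max_length=128`, `truncation=True` and a padding option, and it keeps its special tokens when truncating, which this sketch does not model.

```python
from typing import Dict, List, Tuple

MAX_LENGTH = 128  # max sequence length stated in the finetuning procedure
PAD_ID = 1        # assumption: pad token ID of the RoBERTa-style CodeBERT vocabulary


def sql_only(pairs: List[Dict[str, str]]) -> List[Tuple[str, str]]:
    """Drop the NLQ and keep (sql, label); the field names are hypothetical."""
    return [(record["sql"], record["label"]) for record in pairs]


def pad_or_truncate(ids: List[int], max_length: int = MAX_LENGTH) -> Tuple[List[int], List[int]]:
    """Cut token IDs to max_length, right-pad with PAD_ID, and build the attention mask."""
    ids = ids[:max_length]
    n_pad = max_length - len(ids)
    attention_mask = [1] * len(ids) + [0] * n_pad
    return ids + [PAD_ID] * n_pad, attention_mask


# Hypothetical record in (NLQ, SQL) form, reusing the safe query from the examples.
pairs = [{
    "nlq": "Which package option does the Sky Radio channel have?",
    "sql": "SELECT package_option FROM tv_channel WHERE series_name = 'Sky Radio'",
    "label": "SAFE",
}]
print(sql_only(pairs))

input_ids, attention_mask = pad_or_truncate([0, 100, 200, 2])  # made-up token IDs
print(len(input_ids), sum(attention_mask))  # 128 4
```

Only the SQL string reaches the model; the NLQ is discarded before tokenization, matching step 1.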

Intended Use and Limitations

SQLQueryShield is intended as a post-generation filter or analysis tool in systems that generate or execute SQL queries. Its main role is to detect whether a SQL query is potentially harmful because it contains vulnerability patterns such as SQL injection, improper string concatenation, or unsafe expressions.

Ideal use cases:

- Filtering SQL queries in Text-to-SQL applications

- Post-processing or validating user-generated SQL before execution
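A minimal sketch of the post-generation filter described above, assuming the classifier returns the `[{'label': ..., 'score': ...}]` format shown in the examples in How to Use. The function name `allow_execution` and the `min_safe_score` threshold of 0.9 are hypothetical choices for illustration, not values recommended by the model authors. A canned stub stands in for the model so the sketch runs offline; it does not analyze the SQL at all.

```python
from typing import Callable, Dict, List

# Output shape of the text-classification pipeline, as in the examples:
# [{'label': 'SAFE' or 'MALICIOUS', 'score': float}]
Classifier = Callable[[str], List[Dict[str, object]]]


def allow_execution(query: str, classify: Classifier, min_safe_score: float = 0.9) -> bool:
    """Return True only when the query is labelled SAFE with at least min_safe_score.

    min_safe_score is a hypothetical threshold chosen for illustration.
    """
    top = classify(query)[0]
    return top["label"] == "SAFE" and float(top["score"]) >= min_safe_score


# Hypothetical stand-in for the real pipeline: it returns canned predictions
# and does NOT inspect the query.
CANNED = {
    "SELECT campus FROM campuses WHERE location = '' UNION SELECT database() --": "MALICIOUS",
}


def fake_classifier(query: str) -> List[Dict[str, object]]:
    return [{"label": CANNED.get(query, "SAFE"), "score": 0.99}]


print(allow_execution("SELECT package_option FROM tv_channel WHERE series_name = 'Sky Radio'", fake_classifier))  # True
```

With the real model, the pipeline object created in the How to Use examples (`sql_query_shield`) can be passed as `classify`. Requiring a minimum SAFE score, rather than merely checking the top label, lets a deployment block borderline queries.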

How to Use

Example 1: Malicious

from transformers import pipeline

sql_query_shield = pipeline("text-classification", model="salmane11/SQLQueryShield")

# For the following table schema:
# CREATE TABLE campuses
#   (
#      campus   VARCHAR,
#      location VARCHAR
#   )

query = "SELECT campus FROM campuses WHERE location = '' UNION SELECT database() --"

prediction = sql_query_shield(query)
print(prediction)
#[{'label': 'MALICIOUS', 'score': 0.9995294809341431}]

Example 2: Safe

from transformers import pipeline

sql_query_shield = pipeline("text-classification", model="salmane11/SQLQueryShield")

# For the following table schema:
# CREATE TABLE tv_channel
#   (
#      package_option VARCHAR,
#      series_name    VARCHAR
#   )

query = "SELECT package_option FROM tv_channel WHERE series_name = 'Sky Radio'"


prediction = sql_query_shield(query)
print(prediction)
#[{'label': 'SAFE', 'score': 0.999503493309021}]
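To screen several generated queries at once, a batched call can be used. The sketch below assumes a classifier that maps a list of queries to one `{'label', 'score'}` dict per query. The helper `split_queries` is hypothetical, and the stub simply replays the labels from the two examples above without inspecting the SQL.

```python
from typing import Callable, Dict, List, Tuple

# A batch classifier maps a list of queries to one {'label', 'score'} dict per query.
BatchClassifier = Callable[[List[str]], List[Dict[str, object]]]


def split_queries(queries: List[str], classify: BatchClassifier) -> Tuple[List[str], List[str]]:
    """Partition queries into (safe, flagged) with a single batched classifier call."""
    predictions = classify(queries)
    safe: List[str] = []
    flagged: List[str] = []
    for query, prediction in zip(queries, predictions):
        if prediction["label"] == "SAFE":
            safe.append(query)
        else:
            flagged.append(query)
    return safe, flagged


# Hypothetical stand-in that replays the labels from the two examples above.
KNOWN_LABELS = {
    "SELECT campus FROM campuses WHERE location = '' UNION SELECT database() --": "MALICIOUS",
    "SELECT package_option FROM tv_channel WHERE series_name = 'Sky Radio'": "SAFE",
}


def fake_batch_classifier(queries: List[str]) -> List[Dict[str, object]]:
    return [{"label": KNOWN_LABELS[q], "score": 1.0} for q in queries]


queries = list(KNOWN_LABELS)
safe, flagged = split_queries(queries, fake_batch_classifier)
print(safe)     # ["SELECT package_option FROM tv_channel WHERE series_name = 'Sky Radio'"]
print(flagged)  # ["SELECT campus FROM campuses WHERE location = '' UNION SELECT database() --"]
```

Passing `sql_query_shield` as `classify` should behave the same way, since the text-classification pipeline returns one prediction per input string when given a list.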

Cite our work

Citation