# Source: Hugging Face Space file view (Space status: Sleeping; file size: 14,980 bytes).
# The per-line commit-hash and line-number gutters from the page scrape are not part of the program.
import streamlit as st
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle
import re
import time
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
# Load models and preprocessing components
@st.cache_resource
def load_components():
    """Load every model and preprocessing artifact used by the app.

    Reads two Keras models (``cnn_model.h5``, ``lstm_model.h5``) and four
    pickles (``rf_model.pkl``, ``svm_model.pkl``, ``sql_tokenizer.pkl``,
    ``tfidf_vectorizer.pkl``) through relative paths.

    Returns:
        dict with the keys ``cnn_model``, ``lstm_model``, ``rf_model``,
        ``svm_model``, ``tokenizer``, ``max_sequence_length`` and
        ``tfidf_vectorizer``. ``tokenizer`` and ``max_sequence_length`` are
        taken from the dict stored in ``sql_tokenizer.pkl``.

    Raises:
        Whatever ``load_model``, ``open`` or ``pickle.load`` raise when an
        artifact is missing or unreadable; nothing is caught here.
    """

    def _unpickle(path):
        # NOTE(review): pickle.load can execute code embedded in the file;
        # only ship artifacts that come from a trusted build.
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    # Deep learning models
    cnn = load_model('cnn_model.h5')
    lstm = load_model('lstm_model.h5')

    # Traditional ML models
    random_forest = _unpickle('rf_model.pkl')
    svm = _unpickle('svm_model.pkl')

    # Tokenizer bundle and TF-IDF vectorizer
    tokenizer_bundle = _unpickle('sql_tokenizer.pkl')
    vectorizer = _unpickle('tfidf_vectorizer.pkl')

    loaded = {
        'cnn_model': cnn,
        'lstm_model': lstm,
        'rf_model': random_forest,
        'svm_model': svm,
        'tokenizer': tokenizer_bundle['tokenizer'],
        'max_sequence_length': tokenizer_bundle['max_sequence_length'],
        'tfidf_vectorizer': vectorizer,
    }
    return loaded
# Try to load all components.
# On any failure the app keeps running with rule-based detection only:
# `components` stays None and `model_loading_error` carries a non-empty
# description that the rest of the script uses as the "models unavailable" flag.
components = None
model_loading_error = None
try:
    components = load_components()
except Exception as e:
    # str(e) is "" for exceptions raised without a message; an empty string
    # would read as "no error" in the truthiness checks further down, so fall
    # back to repr(e), which always names the exception type.
    model_loading_error = str(e) or repr(e)
# Preprocess functions
def preprocess_query_for_deep_learning(query, tokenizer, max_sequence_length):
    """Turn one query string into the padded sequence batch fed to the CNN/LSTM.

    Args:
        query: The SQL text to encode.
        tokenizer: Object exposing ``texts_to_sequences``.
        max_sequence_length: Value passed to ``pad_sequences`` as ``maxlen``.

    Returns:
        The result of ``pad_sequences`` for a batch containing only ``query``,
        using ``padding='post'``.
    """
    token_id_rows = tokenizer.texts_to_sequences([query])
    return pad_sequences(
        token_id_rows,
        padding='post',
        maxlen=max_sequence_length,
    )
def preprocess_query_for_traditional_ml(query, tfidf_vectorizer):
    """Vectorize one query for the Random Forest / SVM models.

    Args:
        query: The SQL text to encode.
        tfidf_vectorizer: Object exposing ``transform``.

    Returns:
        Whatever ``tfidf_vectorizer.transform`` returns for a one-element
        batch containing ``query``.
    """
    single_query_batch = [query]
    features = tfidf_vectorizer.transform(single_query_batch)
    return features
# Regex signatures for SQL injection attempts. They are tried in order and the
# first match wins, so the text reported to the user comes from the earliest
# matching entry.
SQL_INJECTION_PATTERNS = [
    # Quote followed later by a SQL line comment
    r"(?i)'.*--",
    # Quote followed by OR/AND with a comparison (classic injection pattern)
    r"(?i)'\s*(OR|AND)\s*['\d\w]+=\s*['\d\w]+",
    # SQL line comment at the start of the input or right after whitespace
    r"(?i)(\s|^)--",
    # Quote, then a semicolon, then a trailing comment (stacked query)
    r"(?i)'.*;.*--",
    # UNION-based injections
    r"(?i)'\s*UNION\s+(ALL\s+)?SELECT",
    # Time-delay attacks
    r"(?i)'\s*;\s*WAITFOR\s+DELAY",
    # DROP/ALTER table attacks
    r"(?i)'\s*;\s*(DROP|ALTER)",
    # Quote followed by OR and a numeric comparison such as 1=1
    r"(?i)'\s*OR\s*'?\d+'?\s*=\s*'?\d+'?",
    # Quote followed by OR and a comparison of two quoted numbers
    r"(?i)'\s*OR\s*(['\"]\d+['\"])=(['\"]\d+['\"])",
    # Batch queries
    r"(?i);\s*(SELECT|INSERT|UPDATE|DELETE|DROP)",
    # CAST attacks
    r"(?i)CAST\s*\(.+AS\s+.+\)",
    # EXEC/EXECUTE after a quote and a semicolon
    r"(?i)'\s*;\s*(EXEC|EXECUTE).*",
]

# Shapes of ordinary statements. A query matching one of these is exempt from
# SQL_INJECTION_PATTERNS so that, for example, "... WHERE a='x' AND b=5" is not
# flagged by the quote-then-AND signature above.
SAFE_SQL_PATTERNS = [
    # Standard SELECT query
    r"(?i)^SELECT\s+[\w\d\s,*]+\s+FROM\s+[\w\d]+(\s+WHERE\s+[\w\d\s=<>']+)?$",
    # Standard INSERT query
    r"(?i)^INSERT\s+INTO\s+[\w\d]+\s*\([^)]+\)\s*VALUES\s*\([^)]+\)$",
    # Standard UPDATE query
    r"(?i)^UPDATE\s+[\w\d]+\s+SET\s+[\w\d\s=',]+(\s+WHERE\s+[\w\d\s=<>']+)?$",
]

# The WHERE character classes in SAFE_SQL_PATTERNS also accept payloads such as
# "... WHERE name='' OR '1'='1'" or "... WHERE id='1' UNION SELECT ...".
# These signatures are therefore checked even when a query looks safe.
_ALLOWLIST_OVERRIDE_PATTERNS = [
    # A second SELECT glued on with UNION
    r"(?i)\bUNION\s+(?:ALL\s+)?SELECT\b",
    # OR/AND followed by a comparison between two literals (1=1, 'a'='a', 1<2):
    # a condition that never involves a column.
    r"(?i)\b(?:OR|AND)(?:\s+|(?='))(?:'[^']*'|\d+)\s*(?:<>|<=|>=|=|<|>)\s*(?:'[^']*'|\d+)",
]


# Rule-based detection function
def detect_sql_injection_with_regex(query):
    """Screen a query against the regex rule set.

    A query shaped like one of ``SAFE_SQL_PATTERNS`` is only checked against
    ``_ALLOWLIST_OVERRIDE_PATTERNS``; any other query is checked against the
    full ``SQL_INJECTION_PATTERNS`` list.

    Args:
        query: The SQL text to inspect. ``None`` or an empty string is
            reported as not malicious.

    Returns:
        ``(True, matched_text)`` for the first signature that matches,
        otherwise ``(False, None)``.
    """
    if not query:
        return False, None

    # The safe shapes are anchored with ^...$, so surrounding whitespace is
    # removed before testing them.
    candidate = query.strip()
    looks_safe = any(re.search(pattern, candidate) for pattern in SAFE_SQL_PATTERNS)

    # The allowlist narrows the rule set instead of skipping detection.
    active_patterns = _ALLOWLIST_OVERRIDE_PATTERNS if looks_safe else SQL_INJECTION_PATTERNS
    for pattern in active_patterns:
        match = re.search(pattern, query)
        if match:
            return True, match.group(0)
    return False, None
# Ensemble prediction function
def predict_with_ensemble(query, components):
    """Run all four models on one query and tally their votes.

    Args:
        query: The SQL text to classify.
        components: Mapping with the keys ``tfidf_vectorizer``, ``rf_model``,
            ``svm_model``, ``tokenizer``, ``max_sequence_length``,
            ``cnn_model`` and ``lstm_model``.

    Returns:
        dict with
        - ``'rf'`` / ``'svm'``: integer labels predicted by those models,
        - ``'cnn'`` / ``'lstm'``: ``{'prediction': int, 'probability': float}``
          where the prediction is 1 when the model output exceeds 0.5,
        - ``'vote_count'``: ``{0: n, 1: m}`` counting labels 0 and 1 across
          the four models.
    """
    # The classical models share one TF-IDF representation of the query.
    tfidf_features = preprocess_query_for_traditional_ml(query, components['tfidf_vectorizer'])
    rf_label = int(components['rf_model'].predict(tfidf_features)[0])
    svm_label = int(components['svm_model'].predict(tfidf_features)[0])

    # The neural models share one padded token sequence.
    token_batch = preprocess_query_for_deep_learning(
        query,
        components['tokenizer'],
        components['max_sequence_length'],
    )

    cnn_score = components['cnn_model'].predict(token_batch)[0][0]
    cnn_label = 1 if cnn_score > 0.5 else 0

    lstm_score = components['lstm_model'].predict(token_batch)[0][0]
    lstm_label = 1 if lstm_score > 0.5 else 0

    # One ballot per model; only labels 0 and 1 are tallied.
    ballots = (rf_label, svm_label, cnn_label, lstm_label)
    tally = {label: ballots.count(label) for label in (0, 1)}

    return {
        'rf': rf_label,
        'svm': svm_label,
        'cnn': {'prediction': cnn_label, 'probability': float(cnn_score)},
        'lstm': {'prediction': lstm_label, 'probability': float(lstm_score)},
        'vote_count': tally,
    }
# Initialize session state
# Each key is created only when absent, so values written during earlier
# reruns of the script are kept.
_SESSION_DEFAULTS = {
    'analysis_stage': 0,
    'regex_result': None,
    'ensemble_result': None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# App title and description
_APP_INTRO = """
This application uses a multi-layered approach to detect potentially malicious SQL queries:
1. **Rule-based detection** using improved regex patterns.
2. **Ensemble learning** with majority voting from 4 models:
- Random Forest
- Support Vector Machine
- Convolutional Neural Network
- Long Short-Term Memory Network.
"""
st.title("🛡️ SQL Injection Detection")
st.markdown(_APP_INTRO)

# Display warning if models couldn't be loaded
if model_loading_error:
    _load_warning = f"⚠️ Some models could not be loaded. The application will only use rule-based detection. Error: {model_loading_error}"
    st.warning(_load_warning)
# Example queries offered in the dropdowns.
# Keys are the category labels shown in the first select box; each value is the
# list of query strings shown in the second one for that category.
example_categories = {
    # Ordinary statements
    "Benign SQL Queries": [
        "SELECT * FROM users WHERE username='admin'",
        "SELECT id, name, price FROM products WHERE category_id=5",
        "SELECT COUNT(*) FROM orders WHERE date > '2023-01-01'",
        "INSERT INTO logs (user_id, action) VALUES (42, 'login')",
        "UPDATE customers SET last_login='2023-06-15' WHERE id=101",
        "DELETE FROM sessions WHERE last_activity < '2023-01-01'",
        "SELECT email FROM subscribers WHERE active=1",
        "INSERT INTO feedback (user_id, message) VALUES (87, 'Great service!')",
        "UPDATE inventory SET stock = stock - 1 WHERE product_id = 300",
        "SELECT name FROM employees WHERE department = 'Sales'",
        "SELECT AVG(rating) FROM reviews WHERE product_id = 55",
        "INSERT INTO audit_log (timestamp, event) VALUES (CURRENT_TIMESTAMP, 'update')",
        "SELECT * FROM appointments WHERE doctor_id = 10 AND status = 'confirmed'",
        "UPDATE settings SET value='dark' WHERE key='theme'",
        "SELECT DISTINCT city FROM customers WHERE country='USA'",
        "DELETE FROM cart_items WHERE user_id=12 AND product_id=78",
        "SELECT MAX(salary) FROM employees WHERE role='manager'",
        "INSERT INTO payments (user_id, amount, method) VALUES (33, 99.99, 'credit')",
        "UPDATE products SET price = price * 1.1 WHERE category_id = 7",
        "SELECT * FROM messages WHERE sender_id = 5 AND is_read = 0"
    ],
    # Injection payloads and injected statements
    "Malicious SQL Queries": [
        "' OR 1=1 --",
        "admin'; DROP TABLE users; --",
        "SELECT * FROM users WHERE username='' UNION SELECT username,password FROM admin_users --",
        "'; WAITFOR DELAY '0:0:10' --",
        "admin' OR '1'='1",
        "' OR 'a'='a",
        "' OR 1=1#",
        "' OR 1=1/*",
        "admin'--",
        "'; EXEC xp_cmdshell('dir'); --",
        "' OR EXISTS(SELECT * FROM users WHERE username = 'admin') --",
        "1; DROP TABLE sessions --",
        "'; SHUTDOWN --",
        "' OR SLEEP(5) --",
        "' AND 1=(SELECT COUNT(*) FROM users) --",
        "admin' AND SUBSTRING(password, 1, 1) = 'a' --",
        "' UNION ALL SELECT NULL,NULL,NULL --",
        "0' OR 1=1 ORDER BY 1 --",
        "1' AND (SELECT COUNT(*) FROM users) > 0 --",
        "' OR (SELECT ASCII(SUBSTRING(password,1,1)) FROM users WHERE username='admin') > 64 --"
    ]
}
# Pick a category, then an example from it; the user may instead type a query.
category = st.selectbox("Choose query category:", options=list(example_categories))
example = st.selectbox("Select an example:", options=example_categories[category])
query_source = st.radio("Query source:", ["Use selected example", "Enter my own query"])

# The text area is only rendered when the user chooses to enter a query.
if query_source == "Use selected example":
    query = example
else:
    query = st.text_area("Enter SQL Query:", placeholder="Type your SQL query here...")
# Analysis process
if st.button("Start Analysis"):
    if query and query.strip():
        st.session_state.analysis_stage = 1
        # Remember exactly which text was analysed. The widgets can change on a
        # later rerun, and step 2 must judge the same query as step 1.
        st.session_state.analyzed_query = query
        # A new run invalidates the ensemble verdict of any previous query.
        st.session_state.ensemble_result = None
        with st.spinner("Running rule-based detection..."):
            time.sleep(0.5)  # Simulate processing time
            is_malicious, matched_pattern = detect_sql_injection_with_regex(query)
        st.session_state.regex_result = (is_malicious, matched_pattern)
    else:
        # Previously an empty query made the button a silent no-op.
        st.warning("Please enter a SQL query to analyze.")

# Rule-based analysis results
if st.session_state.analysis_stage >= 1 and st.session_state.regex_result is not None:
    is_malicious, matched_pattern = st.session_state.regex_result
    # Fall back to the current widget value only if no query was recorded.
    analyzed_query = st.session_state.get("analyzed_query", query)

    st.subheader("Step 1: Rule-Based Detection")
    # Show which text the verdicts refer to, since the selectors above may
    # have been changed since the analysis ran.
    st.code(analyzed_query, language="sql")
    if is_malicious:
        st.error("🚨 SQL Injection Detected (Rule-Based)!")
        st.warning(f"Matched pattern: `{matched_pattern}`")
    else:
        st.success("✅ No SQL injection patterns detected using rules")

    proceed = st.radio("Proceed with ensemble detection?", ["Yes", "No"], index=0)
    # The ensemble step needs the loaded models; `components` is None when
    # loading failed.
    models_available = not model_loading_error and components is not None
    if proceed == "Yes" and models_available:
        if st.button("Run Ensemble Analysis"):
            st.session_state.analysis_stage = 2
            with st.spinner("Running ensemble models..."):
                time.sleep(1)  # Simulate processing time
                ensemble_results = predict_with_ensemble(analyzed_query, components)
            st.session_state.ensemble_result = ensemble_results
# Ensemble analysis results
# NOTE(review): the source's indentation was lost; the nesting below (including
# the reset button living inside this section) is reconstructed — confirm.
if st.session_state.analysis_stage >= 2 and st.session_state.ensemble_result is not None:
    results = st.session_state.ensemble_result
    st.subheader("Step 2: Ensemble Model Detection")
    vote_benign = results['vote_count'][0]
    vote_malicious = results['vote_count'][1]
    total_votes = vote_benign + vote_malicious
    # Strict majority = more than half of the votes cast (3 of 4 models).
    # The previous "> 3" threshold demanded unanimity, so a 3-1 split was
    # reported as ambiguous although the app advertises majority voting.
    majority = total_votes // 2 + 1

    # Create columns for voting visualization
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Safe Votes", vote_benign)
    with col2:
        st.metric("Malicious Votes", vote_malicious)

    # Progress bar for malicious ratio (guarded against an empty tally)
    vote_ratio = vote_malicious / total_votes if total_votes else 0.0
    st.progress(vote_ratio, text=f"Malicious vote ratio: {vote_ratio*100:.0f}%")

    # Display individual model results
    st.markdown("### Individual Model Results")
    model_cols = st.columns(4)

    # Random Forest and SVM only report a label.
    label_only_models = (
        (model_cols[0], "**Random Forest**", 'rf'),
        (model_cols[1], "**SVM**", 'svm'),
    )
    for column, title, key in label_only_models:
        with column:
            st.markdown(title)
            if results[key] == 1:
                st.error("⚠️ Malicious")
            else:
                st.success("✅ Safe")

    # CNN and LSTM also report a probability; the "Safe" branch shows its
    # complement.
    scored_models = (
        (model_cols[2], "**CNN**", 'cnn'),
        (model_cols[3], "**LSTM**", 'lstm'),
    )
    for column, title, key in scored_models:
        with column:
            st.markdown(title)
            malicious_pct = results[key]['probability'] * 100
            if results[key]['prediction'] == 1:
                st.error(f"⚠️ Malicious ({malicious_pct:.1f}%)")
            else:
                st.success(f"✅ Safe ({100-malicious_pct:.1f}%)")

    # Final ensemble verdict
    st.markdown("### Ensemble Verdict")
    is_safe_ensemble = vote_benign >= majority
    is_malicious_ensemble = vote_malicious >= majority
    if is_safe_ensemble:
        st.success(f"✅ Query deemed safe by majority vote ({vote_benign} of {total_votes} safe votes)")
    elif is_malicious_ensemble:
        st.error(f"🚨 SQL Injection Detected by Majority Vote ({vote_malicious} of {total_votes} malicious votes)")
    else:
        st.warning("⚠️ Ambiguous result: votes are evenly split. Please cross-check manually.")

    # Final verdict combining both approaches: either layer flagging the query
    # is enough to call it malicious.
    st.subheader("Final Analysis")
    is_malicious_regex, _ = st.session_state.regex_result
    if is_malicious_regex or is_malicious_ensemble:
        st.error("⚠️ This query appears malicious. Review immediately!")
    elif is_safe_ensemble:
        st.success("✅ Query appears safe based on multi-layer analysis")
    else:
        st.warning("⚠️ Ambiguous result - manual verification required")

    if st.button("Analyze Another Query"):
        # Return to the initial stage and discard both verdicts.
        st.session_state.analysis_stage = 0
        st.session_state.regex_result = None
        st.session_state.ensemble_result = None
        st.rerun()
# Sidebar with additional info
# The texts are kept as module-level constants so the layout code below stays short.
_SIDEBAR_PROCESS_MD = """
### Multi-Layer Detection Process
1. **Rule-Based Detection**
- Fast, pattern-matching approach
- Uses improved regex to identify SQL injection patterns
- Reduces false positives with safe pattern recognition
2. **Ensemble Detection**
- Combines 4 different machine learning models:
- Random Forest
- Support Vector Machine (SVM)
- Convolutional Neural Network (CNN)
- Long Short-Term Memory Network (LSTM)
- Final decision by majority voting
"""

_SIDEBAR_ARCHITECTURE_CODE = """
# Traditional ML
- Random Forest (n_estimators=100)
- SVM (kernel='linear')
# CNN Architecture
Sequential([
Embedding(input_dim=10000, output_dim=128),
Conv1D(filters=64, kernel_size=3, activation='relu'),
MaxPooling1D(pool_size=2),
Dropout(0.5),
Conv1D(filters=128, kernel_size=3, activation='relu'),
MaxPooling1D(pool_size=2),
Flatten(),
Dense(64, activation='relu'),
Dropout(0.5),
Dense(1, activation='sigmoid')
])
# LSTM Architecture
Sequential([
Embedding(input_dim=10000, output_dim=128),
Bidirectional(LSTM(64, return_sequences=True)),
Dropout(0.5),
Bidirectional(LSTM(32)),
Dropout(0.5),
Dense(32, activation='relu'),
Dense(1, activation='sigmoid')
])
"""

_SIDEBAR_HOW_IT_WORKS_MD = """
1. **Step 1:** Rule-based patterns scan for known SQL injection techniques
2. **Step 2:** Ensemble of 4 models evaluates the query structure
3. **Final Analysis:** Combined verdict from both approaches
"""

with st.sidebar:
    st.header("About This App")
    st.markdown(_SIDEBAR_PROCESS_MD)

    st.markdown("### Machine Learning Architecture")
    st.code(_SIDEBAR_ARCHITECTURE_CODE)

    st.markdown("### How It Works")
    st.markdown(_SIDEBAR_HOW_IT_WORKS_MD)

    st.markdown("---")
    st.warning("**Note:** This is a demonstration tool, not a replacement for proper security measures.")
# Footer
# Raw CSS + HTML for a bar pinned to the bottom of the page; it is rendered
# with unsafe_allow_html=True so the markup is emitted as-is.
_FOOTER_HTML = """
<style>
.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: center;
padding: 10px;
border-top: 1px solid #e5e5e5;
}
</style>
<div class="footer">
<p>Developed with ❤️ using Streamlit | SQL Injection Detection System</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)