import os

# Environment flags must be set before TensorFlow is imported to take full
# effect: hide GPUs (forcing CPU execution) and silence C++-level log noise.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import re
import math
import time

import joblib
import numpy as np
import pandas as pd
import streamlit as st
import tensorflow as tf
from collections import Counter
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tldextract
from rapidfuzz import fuzz, process
# Set page config (must be the first Streamlit call)
st.set_page_config(
    page_title="URL Threat Detector",
    page_icon="🛡️",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Belt and braces: even if a GPU is visible, keep TensorFlow off it
tf.config.set_visible_devices([], 'GPU')
# Global configuration
MAX_LEN = 200
FIXED_FEATURE_COLS = [
    'url_length', 'domain_length', 'subdomain_count', 'path_depth',
    'param_count', 'has_ip', 'has_executable', 'has_double_extension',
    'hex_encoded', 'digit_ratio', 'special_char_ratio', 'entropy',
    'is_safe_domain', 'is_uncommon_tld'
]
# Enhanced domain and TLD lists
SAFE_DOMAINS = {
    'google.com', 'google.co.in', 'google.co.uk', 'google.fr', 'google.de',
    'amazon.com', 'amazon.in', 'amazon.co.uk', 'amazon.de', 'amazon.fr',
    'wikipedia.org', 'github.com', 'python.org', 'irs.gov', 'adobe.com',
    'steampowered.com', 'imdb.com', 'weather.com', 'archive.org', 'cdc.gov',
    'microsoft.com', 'apple.com', 'youtube.com', 'facebook.com', 'twitter.com',
    'linkedin.com', 'instagram.com', 'netflix.com', 'reddit.com', 'stackoverflow.com'
}
COMMON_TLDS = {
    'com', 'org', 'net', 'gov', 'edu', 'mil', 'co', 'io', 'ai', 'in',
    'uk', 'us', 'ca', 'au', 'de', 'fr', 'es', 'it', 'nl', 'jp', 'cn',
    'br', 'mx', 'ru', 'ch', 'se', 'no', 'dk', 'fi', 'be', 'at', 'nz'
}
# Initialize tldextract once so the public-suffix list is loaded a single time
tld_extractor = tldextract.TLDExtract()
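# For reference, the extractor splits a host into (subdomain, domain, suffix):
#   tld_extractor('http://mail.google.co.uk')
#   -> subdomain='mail', domain='google', suffix='co.uk'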
def load_char_mapping():
    """Load the character-to-index mapping used by the hybrid models."""
    char_to_idx_path = 'char_to_idx.pkl'
    if not os.path.exists(char_to_idx_path):
        st.error(f"Character mapping file not found: {char_to_idx_path}")
        return None
    return joblib.load(char_to_idx_path)
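# A minimal sketch of how such a mapping could be produced at training time
# (hypothetical -- the training pipeline is not part of this app; the pickle
# shipped with the Space is the one actually used):
#
#   chars = sorted({c for url in training_urls for c in url})
#   char_to_idx = {c: i + 1 for i, c in enumerate(chars)}  # 0 = padding/unknown
#   joblib.dump(char_to_idx, 'char_to_idx.pkl')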
def load_all_models():
    """Load models with CPU optimization"""
    models = {}
    model_dir = "models"
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
        st.warning(f"Created model directory: {model_dir}")

    # Hybrid models
    hybrid_models = {
        'hybrid': 'hybrid_model.h5',
        'hybrid_fold1': 'best_hybrid_fold1.h5',
        'hybrid_fold2': 'best_hybrid_fold2.h5'
    }
    for name, file in hybrid_models.items():
        path = os.path.join(model_dir, file)
        if os.path.exists(path):
            try:
                models[name] = load_model(path)
                st.success(f"Loaded {name}")
            except Exception as e:
                st.error(f"Error loading {name}: {str(e)}")
        else:
            st.warning(f"Model file not found: {path}")

    # Traditional models
    traditional_models = {
        'random_forest': 'random_forest_model.pkl',
        'xgboost': 'xgboost_model.pkl',
    }
    for name, file in traditional_models.items():
        path = os.path.join(model_dir, file)
        if os.path.exists(path):
            try:
                models[name] = joblib.load(path)
                st.success(f"Loaded {name}")
            except Exception as e:
                st.error(f"Error loading {name}: {str(e)}")
        else:
            st.warning(f"Model file not found: {path}")
    return models
def normalize_url(url):
    """Lowercase a URL, strip common prefixes and fragments, and drop
    non-essential query parameters."""
    try:
        is_https = url.lower().startswith('https://')
        url = url.lower()
        # 'https://' is stripped here and re-added below; without it the
        # scheme would be doubled when is_https is True
        prefixes = ['https://', 'http://', 'ftp://', 'www.', 'ww2.', 'web.']
        for prefix in prefixes:
            if url.startswith(prefix):
                url = url[len(prefix):]
        if is_https:
            url = "https://" + url
        url = url.split('#')[0]
        if '?' in url:
            base, query = url.split('?', 1)
            if not any(sd in base for sd in SAFE_DOMAINS):
                params = [p for p in query.split('&') if '=' in p]
                essential_params = [p for p in params if any(
                    kw in p for kw in ['id=', 'ref=', 'token='])]
                url = base + ('?' + '&'.join(essential_params) if essential_params else '')
        # Collapse repeated slashes in the path, but not the '//' in the scheme
        return re.sub(r'(?<!:)/{2,}', '/', url)
    except Exception:
        return url
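# Worked example (given the SAFE_DOMAINS set above):
#   normalize_url('http://www.example.com/path//to?x=1&id=5#frag')
#   -> 'example.com/path/to?id=5'
# ('x=1' is dropped as non-essential; 'id=5' matches the 'id=' keep-list.)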
def extract_url_components(url):
    """Robust URL parsing via tldextract plus manual path/query splitting"""
    try:
        extracted = tld_extractor(url)
        subdomain = extracted.subdomain
        domain = extracted.domain
        suffix = extracted.suffix
        path = ""
        query = ""
        if "/" in url:
            path_start = url.find("/", url.find("//") + 2) if "//" in url else url.find("/")
            if path_start != -1:
                path_query = url[path_start:]
                if "?" in path_query:
                    path, query = path_query.split("?", 1)
                else:
                    path = path_query
        # Fallback: if tldextract found no registered domain, promote the
        # last subdomain label to the domain slot
        if not domain and subdomain:
            domain_parts = subdomain.split('.')
            if len(domain_parts) > 1:
                domain = domain_parts[-1]
                subdomain = '.'.join(domain_parts[:-1])
        return {
            'subdomain': subdomain,
            'domain': domain,
            'suffix': suffix,
            'path': path,
            'query': query
        }
    except Exception:
        return {
            'subdomain': '',
            'domain': '',
            'suffix': '',
            'path': '',
            'query': ''
        }
def calculate_entropy(s):
    """Compute the base-2 Shannon entropy (bits per character) of a string"""
    if not s:
        return 0
    try:
        p, lns = Counter(s), float(len(s))
        return -sum(count/lns * math.log(count/lns, 2) for count in p.values())
    except Exception:
        return 0
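# Worked examples:
#   calculate_entropy('aaaa') -> 0.0  (a single symbol carries no information)
#   calculate_entropy('abab') -> 1.0  (two equiprobable symbols)
#   calculate_entropy('abcd') -> 2.0  (four equiprobable symbols)
# Long random-looking strings score high, which is why entropy is a useful
# signal for algorithmically generated URLs.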
def fuzzy_domain_match(domain):
    """Safe-domain matching: exact match, then base-domain match, then fuzzy.
    Note that the fuzzy step also whitelists near-matches (WRatio score > 85),
    so very close lookalikes of safe domains are treated as safe."""
    if domain in SAFE_DOMAINS:
        return True
    domain_parts = domain.split('.')
    if len(domain_parts) > 2:
        base_domain = '.'.join(domain_parts[-2:])
        if base_domain in SAFE_DOMAINS:
            return True
    best_match, score, _ = process.extractOne(domain, SAFE_DOMAINS, scorer=fuzz.WRatio)
    return score > 85
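# Example behaviour:
#   fuzzy_domain_match('github.com')      -> True (exact match)
#   fuzzy_domain_match('docs.github.com') -> True (base-domain match)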
def extract_robust_features(url):
    """Feature extraction optimized for CPU"""
    try:
        clean_url = re.sub(r'[^\x00-\x7F]+', '', str(url))
        normalized = normalize_url(clean_url)
        components = extract_url_components(clean_url)
        full_domain = f"{components['domain']}.{components['suffix']}" if components['suffix'] else components['domain']

        # Structural features
        url_length = len(clean_url)
        domain_length = len(components['domain'])
        subdomain_count = len(components['subdomain'].split('.')) if components['subdomain'] else 0
        path_depth = components['path'].count('/') if components['path'] else 0
        param_count = len(components['query'].split('&')) if components['query'] else 0

        # Security features
        has_ip = 1 if re.match(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', components['domain']) else 0
        has_executable = 1 if re.search(r'\.(exe|js|jar|bat|sh|py|dll)$', components['path'], re.I) else 0
        has_double_extension = 1 if re.search(r'\.\w+\.\w+$', components['path'], re.I) else 0
        hex_encoded = 1 if re.search(r'%[0-9a-f]{2}', normalized, re.I) else 0

        # Lexical features
        digit_count = sum(c.isdigit() for c in normalized)
        special_chars = sum(not (c.isalnum() or c in ' ./-') for c in normalized)
        digit_ratio = digit_count / url_length if url_length > 0 else 0
        special_char_ratio = special_chars / url_length if url_length > 0 else 0
        entropy = calculate_entropy(normalized)

        # Domain reputation
        is_safe_domain = 1 if fuzzy_domain_match(full_domain) else 0
        is_uncommon_tld = 1 if components['suffix'] and components['suffix'] not in COMMON_TLDS else 0
        if url.startswith('https://') and full_domain in SAFE_DOMAINS:
            is_safe_domain = 1

        return {
            'url_length': url_length,
            'domain_length': domain_length,
            'subdomain_count': subdomain_count,
            'path_depth': path_depth,
            'param_count': param_count,
            'has_ip': has_ip,
            'has_executable': has_executable,
            'has_double_extension': has_double_extension,
            'hex_encoded': hex_encoded,
            'digit_ratio': digit_ratio,
            'special_char_ratio': special_char_ratio,
            'entropy': entropy,
            'is_safe_domain': is_safe_domain,
            'is_uncommon_tld': is_uncommon_tld
        }
    except Exception as e:
        st.error(f"Feature extraction error: {str(e)}")
        return {col: 0 for col in FIXED_FEATURE_COLS}
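# Illustrative reading of the features on a hypothetical URL:
# 'http://203.0.113.7/update/invoice.pdf.exe' yields has_ip=1,
# has_executable=1, has_double_extension=1, and is_safe_domain=0 --
# a combination the models should treat as strongly suspicious.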
def preprocess_url(url, char_to_idx):
    """URL preprocessing for CPU"""
    try:
        clean_url = re.sub(r'[^\x00-\x7F]+', '', str(url))
        normalized = normalize_url(clean_url)
        features = extract_robust_features(clean_url)
        feature_vector = np.array([features.get(col, 0) for col in FIXED_FEATURE_COLS]).reshape(1, -1)
        char_seq = [char_to_idx.get(c, 0) for c in normalized]
        char_seq = pad_sequences([char_seq], maxlen=MAX_LEN, padding='post', truncating='post')
        return char_seq, feature_vector, features
    except Exception as e:
        st.error(f"Preprocessing error: {str(e)}")
        return np.zeros((1, MAX_LEN)), np.zeros((1, len(FIXED_FEATURE_COLS))), {}
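# Shapes handed to the models: char_seq is (1, MAX_LEN) = (1, 200) after
# post-padding/truncation; feature_vector is (1, len(FIXED_FEATURE_COLS)) = (1, 14).
# Characters missing from char_to_idx map to index 0, the padding value.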
def weighted_ensemble_predict(models, char_seq, feature_vector, features):
    """Ensemble prediction for CPU"""
    predictions = []
    weights = {
        'hybrid': 0.25,
        'hybrid_fold1': 0.20,
        'hybrid_fold2': 0.20,
        'xgboost': 0.35
    }
    # Hard override: known-safe domains short-circuit the ensemble
    if features.get('is_safe_domain', 0) == 1:
        return 0.01, [('safe_domain_override', 0.01)]
    for model_name, model in models.items():
        if model_name in weights:
            try:
                if 'hybrid' in model_name:
                    proba = model.predict([char_seq, feature_vector], verbose=0)[0][0]
                else:
                    adjusted_features = feature_vector[:, :14] if feature_vector.shape[1] > 14 else feature_vector
                    proba = model.predict_proba(adjusted_features)[0][1]
                predictions.append((model_name, proba))
            except Exception as e:
                st.error(f"Prediction error in {model_name}: {str(e)}")
    if predictions:
        weighted_sum = sum(p * weights.get(name, 0) for name, p in predictions)
        total_weight = sum(weights.get(name, 0) for name, _ in predictions)
        avg_proba = weighted_sum / total_weight if total_weight > 0 else sum(p for _, p in predictions) / len(predictions)
    else:
        avg_proba = 0.5
    return avg_proba, predictions
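# Worked example of the weighted average: if only 'hybrid' (p=0.80) and
# 'xgboost' (p=0.60) loaded, the ensemble score is
#   (0.80 * 0.25 + 0.60 * 0.35) / (0.25 + 0.35) = 0.41 / 0.60 ~= 0.683
# which crosses the 0.5 threshold and is reported as malicious.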
def analyze_single_url(url, char_to_idx, models):
    """Analyze a single URL"""
    with st.spinner(f"Analyzing URL: {url[:50]}..."):
        start_time = time.time()
        char_seq, feature_vector, features = preprocess_url(url, char_to_idx)
        ensemble_proba, model_predictions = weighted_ensemble_predict(
            models, char_seq, feature_vector, features)
        processing_time = time.time() - start_time

        st.subheader("Analysis Results")
        col1, col2 = st.columns([1, 2])
        with col1:
            if ensemble_proba >= 0.5:
                st.error(f"🔴 **Threat Detected!** (Probability: {ensemble_proba:.4f})")
            else:
                st.success(f"🟢 **Safe URL** (Probability: {ensemble_proba:.4f})")
            st.metric("Processing Time", f"{processing_time*1000:.2f} ms")
            st.subheader("Key Features")
            st.json({
                "URL Length": features.get('url_length', 0),
                "Domain Length": features.get('domain_length', 0),
                "Subdomains": features.get('subdomain_count', 0),
                "Path Depth": features.get('path_depth', 0),
                "Parameters": features.get('param_count', 0),
                "Contains IP": bool(features.get('has_ip', 0)),
                "Contains Executable": bool(features.get('has_executable', 0)),
                "Double Extension": bool(features.get('has_double_extension', 0)),
                "Hex Encoded": bool(features.get('hex_encoded', 0)),
                "Safe Domain": bool(features.get('is_safe_domain', 0)),
                "Uncommon TLD": bool(features.get('is_uncommon_tld', 0)),
                "Entropy": features.get('entropy', 0)
            })
        with col2:
            st.subheader("Model Predictions")
            model_df = pd.DataFrame(model_predictions, columns=['Model', 'Probability'])
            model_df['Prediction'] = model_df['Probability'].apply(
                lambda x: "MALICIOUS" if x >= 0.5 else "SAFE")
            st.bar_chart(model_df.set_index('Model')['Probability'])
            st.write("Detailed Model Results:")
            for model_name, proba in model_predictions:
                pred = "MALICIOUS" if proba >= 0.5 else "SAFE"
                st.write(f"- **{model_name}**: {proba:.4f} ({pred})")
def analyze_batch_urls(urls, char_to_idx, models):
    """Analyze multiple URLs"""
    results = []
    progress_bar = st.progress(0)
    status_text = st.empty()
    for i, url in enumerate(urls):
        status_text.text(f"Processing {i+1}/{len(urls)}: {url[:50]}...")
        progress_bar.progress((i + 1) / len(urls))
        try:
            char_seq, feature_vector, features = preprocess_url(url, char_to_idx)
            ensemble_proba, _ = weighted_ensemble_predict(models, char_seq, feature_vector, features)
            results.append({
                'URL': url,
                'Threat Probability': ensemble_proba,
                'Classification': "MALICIOUS" if ensemble_proba >= 0.5 else "SAFE"
            })
        except Exception as e:
            st.error(f"Error processing {url}: {str(e)}")
    if results:
        results_df = pd.DataFrame(results)
        st.dataframe(results_df)
        csv = results_df.to_csv(index=False).encode('utf-8')
        st.download_button(
            "Download Results",
            csv,
            "url_analysis_results.csv",
            "text/csv",
            key='download-csv'
        )
def main():
    st.title("🛡️ URL Threat Detector (CPU Version)")
    st.markdown("""
    This tool analyzes URLs using machine learning models to detect potential threats.
    Optimized for CPU-only environments.
    """)
    with st.sidebar:
        st.header("About")
        st.markdown("""
        - **Models**: Hybrid CNN+MLP, XGBoost
        - **Features**: URL structure, lexical patterns, domain reputation
        - **Environment**: CPU-only
        """)
        st.header("Example URLs")
        st.code("https://paypal-security-alert.com/login")
        st.code("https://github.com/features/actions")

    # Load resources
    with st.spinner("Loading models..."):
        char_to_idx = load_char_mapping()
        models = load_all_models()
    if not char_to_idx or not models:
        st.error("Failed to load required resources. Please check the model files.")
        return

    # URL input
    st.subheader("Single URL Analysis")
    url_input = st.text_input("Enter URL to analyze:",
                              placeholder="https://example.com",
                              label_visibility="visible")
    if st.button("Analyze URL") and url_input:
        analyze_single_url(url_input, char_to_idx, models)

    # Batch analysis
    st.subheader("Batch Analysis")
    uploaded_file = st.file_uploader("Upload a text file with URLs (one per line)",
                                     type=['txt', 'csv'])
    if uploaded_file is not None:
        urls = [line.decode('utf-8').strip() for line in uploaded_file if line.strip()]
        if urls and st.button("Analyze All URLs"):
            analyze_batch_urls(urls, char_to_idx, models)
if __name__ == "__main__":
    # Suppress TensorFlow's Python-side logging (the C++ log level is set via
    # TF_CPP_MIN_LOG_LEVEL before the import at the top of the file)
    tf.get_logger().setLevel('ERROR')
    # Run the app
    main()