import streamlit as st
import re
import math
import time
import os
import joblib
import numpy as np
import pandas as pd
import torch
import tensorflow as tf
from collections import Counter
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
import tldextract
from rapidfuzz import fuzz, process

# Set page config
st.set_page_config(
    page_title="URL Threat Detector",
    page_icon="🛡️",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Disable GPU usage for TensorFlow and PyTorch
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
tf.config.set_visible_devices([], 'GPU')

# Global configuration
MAX_LEN = 200
FIXED_FEATURE_COLS = [
    'url_length', 'domain_length', 'subdomain_count', 'path_depth',
    'param_count', 'has_ip', 'has_executable', 'has_double_extension',
    'hex_encoded', 'digit_ratio', 'special_char_ratio', 'entropy',
    'is_safe_domain', 'is_uncommon_tld'
]
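# NOTE: feature vectors are built by indexing this list in preprocess_url(),
# so its order should match the column order used when the traditional
# (random forest / XGBoost) models were trained.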

# Enhanced domain and TLD lists
SAFE_DOMAINS = {
    'google.com', 'google.co.in', 'google.co.uk', 'google.fr', 'google.de',
    'amazon.com', 'amazon.in', 'amazon.co.uk', 'amazon.de', 'amazon.fr',
    'wikipedia.org', 'github.com', 'python.org', 'irs.gov', 'adobe.com',
    'steampowered.com', 'imdb.com', 'weather.com', 'archive.org', 'cdc.gov',
    'microsoft.com', 'apple.com', 'youtube.com', 'facebook.com', 'twitter.com',
    'linkedin.com', 'instagram.com', 'netflix.com', 'reddit.com', 'stackoverflow.com'
}

COMMON_TLDS = {
    'com', 'org', 'net', 'gov', 'edu', 'mil', 'co', 'io', 'ai', 'in',
    'uk', 'us', 'ca', 'au', 'de', 'fr', 'es', 'it', 'nl', 'jp', 'cn',
    'br', 'mx', 'ru', 'ch', 'se', 'no', 'dk', 'fi', 'be', 'at', 'nz'
}

# Initialize tldextract
tld_extractor = tldextract.TLDExtract()

@st.cache_resource
def load_char_mapping():
    char_to_idx_path = 'char_to_idx.pkl'
    if not os.path.exists(char_to_idx_path):
        st.error(f"Character mapping file not found: {char_to_idx_path}")
        return None
    return joblib.load(char_to_idx_path)

@st.cache_resource
def load_all_models():
    """Load models with CPU optimization"""
    models = {}
    model_dir = "models"

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
        st.warning(f"Created model directory: {model_dir}")

    # Hybrid models
    hybrid_models = {
        'hybrid': 'hybrid_model.h5',
        'hybrid_fold1': 'best_hybrid_fold1.h5',
        'hybrid_fold2': 'best_hybrid_fold2.h5'
    }

    for name, file in hybrid_models.items():
        path = os.path.join(model_dir, file)
        if os.path.exists(path):
            try:
                models[name] = load_model(path)
                st.success(f"Loaded {name}")
            except Exception as e:
                st.error(f"Error loading {name}: {str(e)}")
        else:
            st.warning(f"Model file not found: {path}")

    # Traditional models
    traditional_models = {
        'random_forest': 'random_forest_model.pkl',
        'xgboost': 'xgboost_model.pkl',
    }

    for name, file in traditional_models.items():
        path = os.path.join(model_dir, file)
        if os.path.exists(path):
            try:
                models[name] = joblib.load(path)
                st.success(f"Loaded {name}")
            except Exception as e:
                st.error(f"Error loading {name}: {str(e)}")
        else:
            st.warning(f"Model file not found: {path}")

    return models

def normalize_url(url):
    """Normalize URL with proper indentation and parenthesis"""
    try:
        is_https = url.lower().startswith('https://')
        url = url.lower()

        prefixes = ['http://', 'ftp://', 'www.', 'ww2.', 'web.']
        for prefix in prefixes:
            if url.startswith(prefix):
                url = url[len(prefix):]

        if is_https:
            url = "https://" + url

        url = url.split('#')[0]

        if '?' in url:
            base, query = url.split('?', 1)
            if not any(sd in base for sd in SAFE_DOMAINS):
                params = [p for p in query.split('&') if '=' in p]
                essential_params = [p for p in params if any(
                    kw in p for kw in ['id=', 'ref=', 'token='])]
                url = base + ('?' + '&'.join(essential_params) if essential_params else '')
        
        # Collapse repeated slashes, but keep the '//' in a scheme such as 'https://'
        return re.sub(r'(?<!:)/{2,}', '/', url)
    except Exception:
        return url

def extract_url_components(url):
    """Robust URL parsing"""
    try:
        extracted = tld_extractor(url)
        subdomain = extracted.subdomain
        domain = extracted.domain
        suffix = extracted.suffix

        path = ""
        query = ""
        if "/" in url:
            path_start = url.find("/", url.find("//") + 2) if "//" in url else url.find("/")
            if path_start != -1:
                path_query = url[path_start:]
                if "?" in path_query:
                    path, query = path_query.split("?", 1)
                else:
                    path = path_query

        if not domain and subdomain:
            domain_parts = subdomain.split('.')
            if len(domain_parts) > 1:
                domain = domain_parts[-1]
                subdomain = '.'.join(domain_parts[:-1])

        return {
            'subdomain': subdomain,
            'domain': domain,
            'suffix': suffix,
            'path': path,
            'query': query
        }
    except Exception:
        return {
            'subdomain': '',
            'domain': '',
            'suffix': '',
            'path': '',
            'query': ''
        }

def calculate_entropy(s):
    """Compute Shannon entropy"""
    if not s:
        return 0
    try:
        p, lns = Counter(s), float(len(s))
        return -sum(count/lns * math.log(count/lns, 2) for count in p.values())
    except Exception:
        return 0

def fuzzy_domain_match(domain):
    """Safe domain matching"""
    if domain in SAFE_DOMAINS:
        return True

    domain_parts = domain.split('.')
    if len(domain_parts) > 2:
        base_domain = '.'.join(domain_parts[-2:])
        if base_domain in SAFE_DOMAINS:
            return True

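    # Fuzzy fallback: treat the domain as safe if its best WRatio similarity
    # to any whitelisted domain exceeds 85.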
    best_match, score, _ = process.extractOne(domain, SAFE_DOMAINS, scorer=fuzz.WRatio)
    return score > 85

def extract_robust_features(url):
    """Feature extraction optimized for CPU"""
    try:
        clean_url = re.sub(r'[^\x00-\x7F]+', '', str(url))
        normalized = normalize_url(clean_url)
        components = extract_url_components(clean_url)
        full_domain = f"{components['domain']}.{components['suffix']}" if components['suffix'] else components['domain']

        # Structural features
        url_length = len(clean_url)
        domain_length = len(components['domain'])
        subdomain_count = len(components['subdomain'].split('.')) if components['subdomain'] else 0
        path_depth = components['path'].count('/') if components['path'] else 0
        param_count = len(components['query'].split('&')) if components['query'] else 0

        # Security features
        has_ip = 1 if re.match(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', components['domain']) else 0
        has_executable = 1 if re.search(r'\.(exe|js|jar|bat|sh|py|dll)$', components['path'], re.I) else 0
        has_double_extension = 1 if re.search(r'\.\w+\.\w+$', components['path'], re.I) else 0
        hex_encoded = 1 if re.search(r'%[0-9a-f]{2}', normalized, re.I) else 0

        # Lexical features
        digit_count = sum(c.isdigit() for c in normalized)
        special_chars = sum(not (c.isalnum() or c in ' ./-') for c in normalized)

        digit_ratio = digit_count / url_length if url_length > 0 else 0
        special_char_ratio = special_chars / url_length if url_length > 0 else 0
        entropy = calculate_entropy(normalized)

        # Domain reputation
        is_safe_domain = 1 if fuzzy_domain_match(full_domain) else 0
        is_uncommon_tld = 1 if components['suffix'] and components['suffix'] not in COMMON_TLDS else 0

        if url.startswith('https://') and full_domain in SAFE_DOMAINS:
            is_safe_domain = 1

        return {
            'url_length': url_length,
            'domain_length': domain_length,
            'subdomain_count': subdomain_count,
            'path_depth': path_depth,
            'param_count': param_count,
            'has_ip': has_ip,
            'has_executable': has_executable,
            'has_double_extension': has_double_extension,
            'hex_encoded': hex_encoded,
            'digit_ratio': digit_ratio,
            'special_char_ratio': special_char_ratio,
            'entropy': entropy,
            'is_safe_domain': is_safe_domain,
            'is_uncommon_tld': is_uncommon_tld
        }
    except Exception as e:
        st.error(f"Feature extraction error: {str(e)}")
        return {col: 0 for col in FIXED_FEATURE_COLS}

def preprocess_url(url, char_to_idx):
    """URL preprocessing for CPU"""
    try:
        clean_url = re.sub(r'[^\x00-\x7F]+', '', str(url))
        normalized = normalize_url(clean_url)
        features = extract_robust_features(clean_url)
        feature_vector = np.array([features.get(col, 0) for col in FIXED_FEATURE_COLS]).reshape(1, -1)

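        # Character-level input for the hybrid models: map each character to its
        # index (0 for unknown characters) and pad/truncate to MAX_LEN.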
        char_seq = [char_to_idx.get(c, 0) for c in normalized]
        char_seq = pad_sequences([char_seq], maxlen=MAX_LEN, padding='post', truncating='post')

        return char_seq, feature_vector, features
    except Exception as e:
        st.error(f"Preprocessing error: {str(e)}")
        return np.zeros((1, MAX_LEN)), np.zeros((1, len(FIXED_FEATURE_COLS))), {}

def weighted_ensemble_predict(models, char_seq, feature_vector, features):
    """Ensemble prediction for CPU"""
    predictions = []
    weights = {
        'hybrid': 0.25,
        'hybrid_fold1': 0.20,
        'hybrid_fold2': 0.20,
        'xgboost': 0.35
    }
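    # Final score is a weighted average over the models that returned a
    # prediction: sum(w_i * p_i) / sum(w_i); unavailable models simply drop out.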

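    # Whitelist short-circuit: URLs on a recognised safe domain bypass the
    # models and receive a near-zero threat probability.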
    if features.get('is_safe_domain', 0) == 1:
        return 0.01, [('safe_domain_override', 0.01)]

    for model_name, model in models.items():
        if model_name in weights:
            try:
                if 'hybrid' in model_name:
                    proba = model.predict([char_seq, feature_vector], verbose=0)[0][0]
                else:
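                    # Traditional models are assumed to expect only the 14 fixed
                    # feature columns; drop any extra columns before predicting.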
                    adjusted_features = feature_vector[:, :14] if feature_vector.shape[1] > 14 else feature_vector
                    proba = model.predict_proba(adjusted_features)[0][1]
                predictions.append((model_name, proba))
            except Exception as e:
                st.error(f"Prediction error in {model_name}: {str(e)}")

    if predictions:
        weighted_sum = sum(p * weights.get(name, 0) for name, p in predictions)
        total_weight = sum(weights.get(name, 0) for name, _ in predictions)
        avg_proba = weighted_sum / total_weight if total_weight > 0 else sum(p for _, p in predictions) / len(predictions)
    else:
        avg_proba = 0.5

    return avg_proba, predictions

def analyze_single_url(url, char_to_idx, models):
    """Analyze a single URL"""
    with st.spinner(f"Analyzing URL: {url[:50]}..."):
        start_time = time.time()
        
        char_seq, feature_vector, features = preprocess_url(url, char_to_idx)
        ensemble_proba, model_predictions = weighted_ensemble_predict(
            models, char_seq, feature_vector, features)
        
        processing_time = time.time() - start_time
        
        st.subheader("Analysis Results")
        col1, col2 = st.columns([1, 2])
        
        with col1:
            if ensemble_proba >= 0.5:
                st.error(f"🔴 **Threat Detected!** (Probability: {ensemble_proba:.4f})")
            else:
                st.success(f"🟢 **Safe URL** (Probability: {ensemble_proba:.4f})")
            
            st.metric("Processing Time", f"{processing_time*1000:.2f} ms")
            
            st.subheader("Key Features")
            st.json({
                "URL Length": features.get('url_length', 0),
                "Domain Length": features.get('domain_length', 0),
                "Subdomains": features.get('subdomain_count', 0),
                "Path Depth": features.get('path_depth', 0),
                "Parameters": features.get('param_count', 0),
                "Contains IP": bool(features.get('has_ip', 0)),
                "Contains Executable": bool(features.get('has_executable', 0)),
                "Double Extension": bool(features.get('has_double_extension', 0)),
                "Hex Encoded": bool(features.get('hex_encoded', 0)),
                "Safe Domain": bool(features.get('is_safe_domain', 0)),
                "Uncommon TLD": bool(features.get('is_uncommon_tld', 0)),
                "Entropy": features.get('entropy', 0)
            })
        
        with col2:
            st.subheader("Model Predictions")
            model_df = pd.DataFrame(model_predictions, columns=['Model', 'Probability'])
            model_df['Prediction'] = model_df['Probability'].apply(
                lambda x: "MALICIOUS" if x >= 0.5 else "SAFE")
            
            st.bar_chart(model_df.set_index('Model')['Probability'])
            
            st.write("Detailed Model Results:")
            for model_name, proba in model_predictions:
                pred = "MALICIOUS" if proba >= 0.5 else "SAFE"
                st.write(f"- **{model_name}**: {proba:.4f} ({pred})")

def analyze_batch_urls(urls, char_to_idx, models):
    """Analyze multiple URLs"""
    results = []
    progress_bar = st.progress(0)
    status_text = st.empty()
    
    for i, url in enumerate(urls):
        status_text.text(f"Processing {i+1}/{len(urls)}: {url[:50]}...")
        progress_bar.progress((i + 1) / len(urls))
        
        try:
            char_seq, feature_vector, features = preprocess_url(url, char_to_idx)
            ensemble_proba, _ = weighted_ensemble_predict(models, char_seq, feature_vector, features)
            results.append({
                'URL': url,
                'Threat Probability': ensemble_proba,
                'Classification': "MALICIOUS" if ensemble_proba >= 0.5 else "SAFE"
            })
        except Exception as e:
            st.error(f"Error processing {url}: {str(e)}")
    
    if results:
        results_df = pd.DataFrame(results)
        st.dataframe(results_df)
        
        csv = results_df.to_csv(index=False).encode('utf-8')
        st.download_button(
            "Download Results",
            csv,
            "url_analysis_results.csv",
            "text/csv",
            key='download-csv'
        )

def main():
    st.title("🛡️ URL Threat Detector (CPU Version)")
    st.markdown("""
    This tool analyzes URLs using machine learning models to detect potential threats.
    Optimized for CPU-only environments.
    """)
    
    with st.sidebar:
        st.header("About")
        st.markdown("""
        - **Models**: Hybrid CNN+MLP, XGBoost
        - **Features**: URL structure, lexical patterns, domain reputation
        - **Environment**: CPU-only
        """)
        
        st.header("Example URLs")
        st.code("https://paypal-security-alert.com/login")
        st.code("https://github.com/features/actions")
    
    # Load resources
    with st.spinner("Loading models..."):
        char_to_idx = load_char_mapping()
        models = load_all_models()
    
    if not char_to_idx or not models:
        st.error("Failed to load required resources. Please check the model files.")
        return
    
    # URL input
    st.subheader("Single URL Analysis")
    url_input = st.text_input("Enter URL to analyze:",
                              placeholder="https://example.com",
                              label_visibility="visible")
    
    if st.button("Analyze URL") and url_input:
        analyze_single_url(url_input, char_to_idx, models)
    
    # Batch analysis
    st.subheader("Batch Analysis")
    uploaded_file = st.file_uploader("Upload a text file with URLs (one per line)", 
                                   type=['txt', 'csv'])
    
    if uploaded_file is not None:
        urls = [line.decode('utf-8').strip() for line in uploaded_file if line.strip()]
        if urls and st.button("Analyze All URLs"):
            analyze_batch_urls(urls, char_to_idx, models)

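# Run the app with Streamlit from the command line, e.g.: streamlit run <this_file>.py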
if __name__ == "__main__":
    # Configure TensorFlow logging
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    tf.get_logger().setLevel('ERROR')
    
    # Run the app
    main()