"""Model service for XSS detection - loads models from the Hugging Face Hub."""

import os
import re
from typing import List, Tuple

import torch
from transformers import RobertaForSequenceClassification, RobertaTokenizer
|
|
class ModelService:
    """Loads the CodeBERT tokenizer and the PHP/JS XSS classification models."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Both classifiers share the CodeBERT tokenizer.
        self.tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

        # PHP classifier
        php_model_repo = os.getenv('PHP_MODEL_REPO', 'mekbus/codebert-xss-php')
        try:
            self.php_model = RobertaForSequenceClassification.from_pretrained(php_model_repo)
            self.php_model.to(self.device)
            self.php_model.eval()
            print(f"PHP model loaded from {php_model_repo}")
        except Exception as e:
            print(f"PHP model not found: {e}")
            self.php_model = None

        # JavaScript classifier
        js_model_repo = os.getenv('JS_MODEL_REPO', 'mekbus/codebert-xss-js')
        try:
            self.js_model = RobertaForSequenceClassification.from_pretrained(js_model_repo)
            self.js_model.to(self.device)
            self.js_model.eval()
            print(f"JS model loaded from {js_model_repo}")
        except Exception as e:
            print(f"JS model not found: {e}")
            self.js_model = None
|
    def extract_php_blocks(self, code: str) -> str:
        """Extract PHP code from mixed PHP/HTML and remove comments."""
        php_blocks = re.findall(r'<\?(?:php)?(.*?)(?:\?>|$)', code, re.DOTALL | re.IGNORECASE)

        if php_blocks:
            processed_blocks = []
            for block in php_blocks:
                block = block.strip()
                # Normalize short echo tags (<?= ... ?>) into explicit echo statements.
                if block.startswith('='):
                    block = 'echo ' + block[1:].strip() + ';'
                processed_blocks.append(block)
            php_code = '\n'.join(processed_blocks)
        else:
            # No PHP tags found: treat the whole input as PHP.
            php_code = code

        # Strip block comments, line comments, and blank lines.
        php_code = re.sub(r'/\*.*?\*/', '', php_code, flags=re.DOTALL)
        php_code = re.sub(r'//.*$', '', php_code, flags=re.MULTILINE)
        php_code = re.sub(r'#.*$', '', php_code, flags=re.MULTILINE)
        php_code = re.sub(r'\n\s*\n+', '\n', php_code.strip())

        return php_code
|
    def extract_js_code(self, code: str) -> str:
        """Extract and clean JavaScript code."""
        # Strip block comments, line comments, and blank lines.
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
        code = re.sub(r'//.*$', '', code, flags=re.MULTILINE)
        code = re.sub(r'\n\s*\n+', '\n', code.strip())
        return code
|
    def chunk_code(self, code: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
        """Split large code into overlapping chunks."""
        lines = code.split('\n')
        chunks = []
        # Chunking is line-based: roughly 50 lines per chunk with a 6-line overlap
        # approximates the max_tokens/overlap budget for typical code density.
        max_lines = 50
        overlap_lines = 6

        i = 0
        while i < len(lines):
            chunk_lines = lines[i:i + max_lines]
            chunk = '\n'.join(chunk_lines)
            if chunk.strip():
                chunks.append(chunk)
            i += max_lines - overlap_lines

        return chunks if chunks else [code]
|
    def predict_single(self, code: str, model) -> Tuple[float, float]:
        """Run one forward pass and return (safe_prob, vulnerable_prob)."""
        inputs = self.tokenizer(
            code,
            return_tensors='pt',
            truncation=True,
            max_length=512,
            padding='max_length'
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.softmax(outputs.logits, dim=1)
        return probs[0][0].item(), probs[0][1].item()
|
    def predict(self, code: str, language: str) -> Tuple[bool, float, str]:
        """Predict if code is vulnerable; returns (is_vulnerable, confidence, verdict)."""
        result = self.predict_multi(code, language)
        if result['vulnerabilities']:
            # Report the highest-confidence finding.
            max_vuln = max(result['vulnerabilities'], key=lambda x: x['confidence'])
            return True, max_vuln['confidence'], "VULNERABLE"
        else:
            return False, result['max_confidence'], "SAFE"
|
    def predict_multi(self, code: str, language: str) -> dict:
        """Predict vulnerabilities - returns multiple if found, using token-based chunking."""
        if language == 'php':
            model = self.php_model
            code = self.extract_php_blocks(code)
        elif language in ['js', 'javascript']:
            model = self.js_model
            code = self.extract_js_code(code)
        else:
            raise ValueError(f"Unsupported language: {language}")

        if model is None:
            raise RuntimeError(f"{language.upper()} model not loaded")

        vulnerabilities = []
        max_vuln_prob = 0.0
        threshold = 0.5
        max_length = 512
        chunk_overlap = 50

        tokens = self.tokenizer.encode(code, add_special_tokens=False)

        # Short input: a single forward pass covers the whole snippet.
        # max_length - 2 leaves room for the <s> and </s> special tokens.
        if len(tokens) <= max_length - 2:
            safe_prob, vuln_prob = self.predict_single(code, model)
            max_vuln_prob = vuln_prob
            if vuln_prob >= threshold:
                vulnerabilities.append({
                    'chunk_id': 1,
                    'start_line': 1,
                    'end_line': len(code.split('\n')),
                    'confidence': vuln_prob
                })
        else:
            # Long input: slide a fixed-size token window with overlap so code
            # spanning a chunk boundary is still seen in one piece.
            chunk_size = max_length - 2
            stride = chunk_size - chunk_overlap
            chunks = []

            for i in range(0, len(tokens), stride):
                chunk_tokens = tokens[i:i + chunk_size]
                # Skip tiny trailing chunks that carry too little context.
                if len(chunk_tokens) < 50:
                    continue
                chunks.append(chunk_tokens)

            print(f"Long {language.upper()} code ({len(tokens)} tokens) -> {len(chunks)} chunks")

            # Map token chunks back to approximate line ranges for reporting.
            lines = code.split('\n')
            total_lines = len(lines)
            lines_per_chunk = max(1, total_lines // len(chunks)) if chunks else total_lines

            for i, chunk_tokens in enumerate(chunks):
                chunk_text = self.tokenizer.decode(chunk_tokens)
                safe_prob, vuln_prob = self.predict_single(chunk_text, model)

                if vuln_prob > max_vuln_prob:
                    max_vuln_prob = vuln_prob

                if vuln_prob >= threshold:
                    start_line = i * lines_per_chunk + 1
                    end_line = min(start_line + lines_per_chunk - 1, total_lines)
                    vulnerabilities.append({
                        'chunk_id': i + 1,
                        'start_line': start_line,
                        'end_line': end_line,
                        'confidence': vuln_prob
                    })

        if vulnerabilities:
            scores = [f"{v['confidence']:.1%}" for v in vulnerabilities]
            print(f"Chunk scores: {scores}")
        else:
            print("Chunk scores: all safe")
        print(f"Max vulnerability score: {max_vuln_prob:.1%}")

        return {
            'is_vulnerable': len(vulnerabilities) > 0,
            'max_confidence': max_vuln_prob,
            'vulnerabilities': vulnerabilities
        }
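

# ---------------------------------------------------------------------------
# Minimal usage sketch (hypothetical, not part of the service itself): shows
# how the two public entry points are meant to be called. It assumes the
# default Hugging Face repos above (or PHP_MODEL_REPO / JS_MODEL_REPO) resolve
# to valid checkpoints; otherwise predict()/predict_multi() raise RuntimeError
# because the corresponding model stays None.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    service = ModelService()

    sample_php = "<?php echo $_GET['name']; ?>"
    is_vuln, confidence, verdict = service.predict(sample_php, 'php')
    print(f"PHP sample: {verdict} ({confidence:.1%})")

    sample_js = "document.write(location.hash.substring(1));"
    report = service.predict_multi(sample_js, 'javascript')
    print(f"JS sample vulnerable: {report['is_vulnerable']} "
          f"(max confidence {report['max_confidence']:.1%})")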