# STA-AI / src/security_system.py
# Last commit by saemstunes: "Update src/security_system.py" (8ff5fb7, verified)
import hashlib
import logging
import math
import re
import time
from collections import Counter
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
class AdvancedSecuritySystem:
    """
    Advanced security system for input validation, rate limiting, and threat detection.
    Protects the AI system from abuse and malicious inputs.

    All state lives in memory on the instance:

    * ``rate_limits`` maps ``"user_<id>"`` / ``"ip_<addr>"`` keys to the
      timestamps of admitted requests inside the last minute.
    * ``suspicious_ips`` maps ``"user_<id>"`` / ``"ip_<addr>"`` keys to the
      number of times :meth:`mark_suspicious` flagged them (despite its name
      it tracks users as well as IPs).
    * ``security_log`` holds the most recent 1000 audit events.

    The class does no locking of its own.
    NOTE(review): confirm instances are not shared across threads.
    """

    def __init__(self):
        # Sliding-window request timestamps, keyed "user_<id>" / "ip_<addr>".
        self.rate_limits = {}
        # Suspicion counters, keyed "user_<id>" / "ip_<addr>" (see mark_suspicious).
        self.suspicious_ips = {}
        # Rolling audit trail appended by log_security_event (capped at 1000).
        self.security_log = []

        # Suspicious patterns for input validation; every pattern that matches
        # adds 20 to the risk score in validate_input().
        # NOTE(review): several entries are broad keyword matches that also hit
        # benign text (e.g. "script" inside "description", "key", "root",
        # "rock & roll"); confirm the tuning for this application.
        self.suspicious_patterns = [
            # SQL Injection patterns
            r"(?i)(union.*select|select.*from|insert.*into|delete.*from|drop.*table)",
            r"(?i)(or.*1=1|and.*1=1|exec.*\(|xp_cmdshell)",
            r"(\b)(DROP|DELETE|INSERT|UPDATE|ALTER)(\b)",
            # XSS patterns
            r"(?i)(script|javascript|onload|onerror|onclick|alert\(|document\.cookie)",
            r"<.*>.*</.*>",  # HTML tags
            # Command injection
            r"[;&|`]\s*\w+",
            r"\$\(.*\)",
            # Path traversal
            r"\.\./|\.\.\\",
            # Sensitive data patterns
            r"(?i)(password|token|key|secret|auth|credential)",
            r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",  # IP addresses
            # Excessive length or repetition
            r".{10000,}",  # Very long inputs
            r"(.)\1{50,}",  # Repeated characters
            # Admin/privilege patterns
            r"(?i)(admin|root|sudo|su -|chmod|chown)",
        ]

        # Rate limiting configuration. Only "requests_per_minute" is read by
        # check_rate_limit(); "burst_capacity" is not used anywhere in this class.
        self.rate_limit_config = {
            "default": {"requests_per_minute": 60, "burst_capacity": 10},
            "anonymous": {"requests_per_minute": 30, "burst_capacity": 5},
            "suspicious": {"requests_per_minute": 10, "burst_capacity": 2},
        }

        self.setup_logging()

    def setup_logging(self):
        """Attach the module logger to the instance and set its level to INFO."""
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

    def check_request(self, query: str, user_id: str, ip_address: Optional[str] = None) -> Dict[str, Any]:
        """
        Comprehensive security check for incoming requests.

        Args:
            query: User's query text.
            user_id: User identifier ("anonymous" selects the anonymous rate tier).
            ip_address: Optional IP address for IP-based checks.

        Returns:
            Security assessment with keys ``is_suspicious``, ``alerts``,
            ``risk_score``, ``allowed`` and ``rate_limit_info``. A score of
            50 or more marks the request suspicious; 80 or more blocks it.
            Rate-limited requests are blocked with a score of 100.
        """
        result = {
            "is_suspicious": False,
            "alerts": [],
            "risk_score": 0,
            "allowed": True,
            "rate_limit_info": {},
        }

        # Rate limiting check (also records this request if it is admitted).
        rate_limit_result = self.check_rate_limit(user_id, ip_address)
        result["rate_limit_info"] = rate_limit_result
        if not rate_limit_result["allowed"]:
            result["is_suspicious"] = True
            result["allowed"] = False
            result["alerts"].append("Rate limit exceeded")
            result["risk_score"] = 100
            # Audit rate-limit blocks as well; previously this early return
            # skipped logging, so they never reached security_log or the
            # blocked-request count in get_security_stats().
            self.log_security_event(user_id, ip_address, query, result)
            return result

        # Input validation and pattern matching
        validation_result = self.validate_input(query, user_id)
        result["alerts"].extend(validation_result["alerts"])
        result["risk_score"] += validation_result["risk_score"]

        # IP reputation check (if IP provided)
        if ip_address:
            ip_result = self.check_ip_reputation(ip_address)
            result["alerts"].extend(ip_result["alerts"])
            result["risk_score"] += ip_result["risk_score"]

        # Determine overall suspicion
        if result["risk_score"] >= 50:
            result["is_suspicious"] = True
        if result["risk_score"] >= 80:
            result["allowed"] = False

        self.log_security_event(user_id, ip_address, query, result)
        return result

    def check_rate_limit(self, user_id: str, ip_address: Optional[str] = None) -> Dict[str, Any]:
        """
        Apply a one-minute sliding-window limit to the user and, if given, the IP.

        The limit tier is "default", "anonymous" (when ``user_id == "anonymous"``)
        or "suspicious" (when the user or IP is flagged), and the same limit is
        applied to both counters. Only admitted requests are recorded.

        Returns:
            ``allowed``; ``user_requests`` / ``ip_requests``, the number of
            earlier requests still inside the window (the current request is
            not counted); ``user_limit`` / ``ip_limit``, where ``ip_limit`` is
            "N/A" when no IP is given; and ``retry_after`` in seconds (0 when
            allowed).
        """
        current_time = datetime.now()
        window = timedelta(minutes=1)
        user_key = f"user_{user_id}"
        ip_key = f"ip_{ip_address}" if ip_address else None

        # Get rate limit configuration
        user_config = self.rate_limit_config.get("default")
        if user_id == "anonymous":
            user_config = self.rate_limit_config.get("anonymous", user_config)
        # Flagged users/IPs get the tightest tier.
        if self.is_suspicious_user(user_id) or (ip_address and self.is_suspicious_ip(ip_address)):
            user_config = self.rate_limit_config.get("suspicious", user_config)
        limit = user_config["requests_per_minute"]

        # Drop timestamps that have slid out of the window.
        self.rate_limits[user_key] = [
            t for t in self.rate_limits.get(user_key, []) if current_time - t < window
        ]
        if ip_key:
            self.rate_limits[ip_key] = [
                t for t in self.rate_limits.get(ip_key, []) if current_time - t < window
            ]

        # Count both keys *before* recording this request so the two figures in
        # the result are comparable (ip_requests used to be counted afterwards).
        user_requests = len(self.rate_limits[user_key])
        ip_requests = len(self.rate_limits[ip_key]) if ip_key else 0
        allowed = user_requests < limit and (ip_key is None or ip_requests < limit)

        # Only admitted requests consume quota.
        if allowed:
            self.rate_limits[user_key].append(current_time)
            if ip_key:
                self.rate_limits[ip_key].append(current_time)

        return {
            "allowed": allowed,
            "user_requests": user_requests,
            "user_limit": limit,
            "ip_requests": ip_requests,
            "ip_limit": limit if ip_key else "N/A",
            "retry_after": 0 if allowed else 60,
        }

    def validate_input(self, query: str, user_id: str) -> Dict[str, Any]:
        """
        Score a query against heuristic abuse signals.

        Args:
            query: User's query text.
            user_id: User identifier (not used by the current checks).

        Returns:
            ``{"alerts": [...], "risk_score": int}``.
        """
        result = {
            "alerts": [],
            "risk_score": 0,
        }

        # Pattern matching: only the existence of a match matters, so stop at
        # the first one (re.search) rather than collecting all (re.findall).
        for pattern in self.suspicious_patterns:
            if re.search(pattern, query):
                result["alerts"].append(f"Suspicious pattern detected: {pattern[:50]}...")
                result["risk_score"] += 20

        # Query length analysis
        query_length = len(query)
        if query_length > 10000:
            result["alerts"].append("Excessively long query detected")
            result["risk_score"] += 30
        elif query_length > 5000:
            result["alerts"].append("Very long query detected")
            result["risk_score"] += 15

        # Special character analysis: anything that is not a word character,
        # whitespace, '.', '?' or '!' counts as "special".
        special_chars = len(re.findall(r'[^\w\s\.\?\!]', query))
        special_char_ratio = special_chars / max(query_length, 1)
        if special_char_ratio > 0.3:
            result["alerts"].append("High percentage of special characters")
            result["risk_score"] += 25
        elif special_char_ratio > 0.2:
            result["alerts"].append("Elevated special character usage")
            result["risk_score"] += 10

        # Entropy analysis (for encrypted/encoded content).
        # NOTE(review): character-level entropy of base64 text is bounded by
        # about log2(65) ~= 6.02 bits, so this threshold rarely fires for base64.
        entropy = self.calculate_entropy(query)
        if entropy > 6.0:
            result["alerts"].append("High entropy content detected")
            result["risk_score"] += 20

        return result

    def check_ip_reputation(self, ip_address: str) -> Dict[str, Any]:
        """
        Check IP reputation (basic in-memory implementation).

        Returns:
            ``{"alerts": [...], "risk_score": int}``.
        """
        result = {
            "alerts": [],
            "risk_score": 0,
        }

        # Check if IP is in suspicious list
        if self.is_suspicious_ip(ip_address):
            result["alerts"].append("IP address has suspicious history")
            result["risk_score"] += 40

        # Exact match against three local addresses only; other loopback or
        # private-range addresses are not detected here.
        if ip_address in ["127.0.0.1", "localhost", "0.0.0.0"]:
            result["alerts"].append("Local IP address detected")
            result["risk_score"] += 10

        # Check for rapid requests from this IP (timestamps kept by check_rate_limit).
        recent_requests = len(self.rate_limits.get(f"ip_{ip_address}", []))
        if recent_requests > 50:
            result["alerts"].append("High request volume from IP")
            result["risk_score"] += 15

        return result

    def calculate_entropy(self, text: str) -> float:
        """
        Return the Shannon entropy of ``text`` in bits per character.

        Returns 0.0 for an empty string.
        """
        if not text:
            return 0.0

        text_length = len(text)
        entropy = 0.0
        # One Counter pass instead of text.count() per distinct character,
        # which rescanned the whole string for every symbol.
        for count in Counter(text).values():
            p_x = count / text_length
            entropy -= p_x * math.log2(p_x)
        return entropy

    def is_suspicious_user(self, user_id: str) -> bool:
        """Return True once the user has been flagged more than 5 times."""
        # In a real implementation, this would check a database.
        # User counters share the suspicious_ips dict under a "user_" prefix.
        return self.suspicious_ips.get(f"user_{user_id}", 0) > 5

    def is_suspicious_ip(self, ip_address: str) -> bool:
        """Return True once the IP has been flagged more than 3 times."""
        return self.suspicious_ips.get(f"ip_{ip_address}", 0) > 3

    def mark_suspicious(self, user_id: str, ip_address: Optional[str] = None, reason: str = ""):
        """Increment the suspicion counter for the user and/or IP and log a warning."""
        if user_id:
            user_key = f"user_{user_id}"
            self.suspicious_ips[user_key] = self.suspicious_ips.get(user_key, 0) + 1
        if ip_address:
            ip_key = f"ip_{ip_address}"
            self.suspicious_ips[ip_key] = self.suspicious_ips.get(ip_key, 0) + 1
        self.logger.warning(
            "Marked as suspicious - User: %s, IP: %s, Reason: %s", user_id, ip_address, reason
        )

    def log_security_event(self, user_id: str, ip_address: Optional[str], query: str, result: Dict[str, Any]):
        """
        Append an audit event for a checked request.

        Only the first 100 characters of the query are stored. The log keeps the
        newest 1000 events. Events with a risk score of 50 or more are also sent
        to the logger as warnings.
        """
        event = {
            "timestamp": datetime.now().isoformat(),
            "user_id": user_id,
            "ip_address": ip_address,
            "query_preview": (query[:100] + "...") if len(query) > 100 else query,
            "query_length": len(query),
            "risk_score": result["risk_score"],
            "alerts": result["alerts"],
            "allowed": result["allowed"],
            "is_suspicious": result["is_suspicious"],
        }
        self.security_log.append(event)

        # Keep only last 1000 events
        if len(self.security_log) > 1000:
            self.security_log = self.security_log[-1000:]

        if result["risk_score"] >= 50:
            self.logger.warning(
                "Security alert: User %s - Score: %s - Alerts: %s",
                user_id, result["risk_score"], result["alerts"],
            )

    def get_security_stats(self) -> Dict[str, Any]:
        """
        Summarize the last 24 hours of events and the current tracking state.

        The suspicious-user/IP counts include any key flagged at least once,
        which is a looser criterion than is_suspicious_user / is_suspicious_ip.
        """
        now = datetime.now()
        day = timedelta(hours=24)
        recent_events = [
            e for e in self.security_log
            if now - datetime.fromisoformat(e["timestamp"]) < day
        ]
        blocked_events = [e for e in recent_events if not e["allowed"]]
        suspicious_events = [e for e in recent_events if e["is_suspicious"]]

        return {
            "total_events_24h": len(recent_events),
            "blocked_requests_24h": len(blocked_events),
            "suspicious_requests_24h": len(suspicious_events),
            "current_suspicious_users": sum(
                1 for k, v in self.suspicious_ips.items() if k.startswith("user_") and v > 0
            ),
            "current_suspicious_ips": sum(
                1 for k, v in self.suspicious_ips.items() if k.startswith("ip_") and v > 0
            ),
            "rate_limits_tracked": len(self.rate_limits),
        }

    def sanitize_input(self, text: str) -> str:
        """
        Strip characters and keywords associated with injection attacks.

        This is a blunt blacklist. Keyword removal also hits substrings of benign
        words (e.g. "description" becomes "deion"). It is not a substitute for
        context-appropriate escaping or parameterized queries.
        """
        if not text:
            return ""

        # Remove potentially dangerous characters
        sanitized = re.sub(r'[<>"\']', '', text)
        # Remove SQL injection patterns
        sanitized = re.sub(r'(\b)(DROP|DELETE|INSERT|UPDATE|ALTER|EXEC)(\b)', '', sanitized, flags=re.IGNORECASE)
        # Remove JavaScript and HTML patterns
        sanitized = re.sub(r'(javascript|script|onload|onerror|onclick)', '', sanitized, flags=re.IGNORECASE)
        # Remove command injection patterns
        sanitized = re.sub(r'[;&|`]\s*\w+', '', sanitized)

        return sanitized.strip()

    def reset_rate_limits(self, user_id: Optional[str] = None, ip_address: Optional[str] = None):
        """Forget the rate-limit window for the given user and/or IP."""
        if user_id:
            self.rate_limits.pop(f"user_{user_id}", None)
        if ip_address:
            self.rate_limits.pop(f"ip_{ip_address}", None)