saemstunes committed · verified
Commit 8ff5fb7 · 1 Parent(s): 9b24530

Update src/security_system.py

Files changed (1)
  1. src/security_system.py +306 -42
src/security_system.py CHANGED
@@ -1,87 +1,351 @@
  import re
  import time
  from datetime import datetime, timedelta
- from typing import Dict, List, Optional
  import logging

- class SecuritySystem:
-     """Security system for input validation and rate limiting"""

      def __init__(self):
          self.rate_limits = {}
          self.suspicious_patterns = [
-             r"(?i)(password|token|key|secret)",
-             r"(?i)(delete|drop|alter|update).*table",
-             r"(?i)(script|javascript|onload|onerror)",
-             r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",
-             r"(?i)(admin|root|sudo)"
          ]
          self.setup_logging()

      def setup_logging(self):
-         """Setup logging"""
          self.logger = logging.getLogger(__name__)

-     def check_request(self, query: str, user_id: str) -> Dict:
-         """Check request for security issues"""
          result = {
              "is_suspicious": False,
              "alerts": [],
-             "risk_score": 0
          }

-         # Rate limiting
-         if not self.check_rate_limit(user_id):
              result["is_suspicious"] = True
              result["alerts"].append("Rate limit exceeded")
              result["risk_score"] = 100
              return result

          # Pattern matching
          for pattern in self.suspicious_patterns:
-             if re.search(pattern, query):
-                 result["alerts"].append(f"Suspicious pattern: {pattern}")
                  result["risk_score"] += 20

          # Query length analysis
-         if len(query) > 10000:
-             result["alerts"].append("Excessively long query")
              result["risk_score"] += 30

          # Special character analysis
-         special_chars = len(re.findall(r'[^\w\s]', query))
-         if special_chars > len(query) * 0.3:
              result["alerts"].append("High percentage of special characters")
              result["risk_score"] += 25

-         # Determine if suspicious
-         if result["risk_score"] >= 50:
-             result["is_suspicious"] = True

          return result

-     def check_rate_limit(self, user_id: str, requests_per_minute: int = 60) -> bool:
-         """Check rate limit for user"""
-         current_time = datetime.now()

-         # Clean old entries
-         self.rate_limits[user_id] = [
-             t for t in self.rate_limits.get(user_id, [])
-             if current_time - t < timedelta(minutes=1)
-         ]

-         # Check rate limit
-         if len(self.rate_limits[user_id]) >= requests_per_minute:
-             return False

-         # Add current request
-         self.rate_limits[user_id].append(current_time)
-         return True

      def sanitize_input(self, text: str) -> str:
-         """Sanitize user input"""
          # Remove potentially dangerous characters
-         sanitized = re.sub(r'[<>"\'&]', '', text)
-         sanitized = re.sub(r'(\b)(DROP|DELETE|INSERT|UPDATE)(\b)', '', sanitized, flags=re.IGNORECASE)
-         sanitized = re.sub(r';\s*\w+', '', sanitized)
-         return sanitized.strip()
  import re
  import time
  from datetime import datetime, timedelta
+ from typing import Dict, List, Optional, Tuple
  import logging
+ import hashlib

+ class AdvancedSecuritySystem:
+     """
+     Advanced security system for input validation, rate limiting, and threat detection.
+     Protects the AI system from abuse and malicious inputs.
+     """

      def __init__(self):
          self.rate_limits = {}
+         self.suspicious_ips = {}
+         self.security_log = []
+
+         # Suspicious patterns for input validation
          self.suspicious_patterns = [
+             # SQL Injection patterns
+             r"(?i)(union.*select|select.*from|insert.*into|delete.*from|drop.*table)",
+             r"(?i)(or.*1=1|and.*1=1|exec.*\(|xp_cmdshell)",
+             r"(\b)(DROP|DELETE|INSERT|UPDATE|ALTER)(\b)",
+
+             # XSS patterns
+             r"(?i)(script|javascript|onload|onerror|onclick|alert\(|document\.cookie)",
+             r"<.*>.*</.*>",  # HTML tags
+
+             # Command injection
+             r"[;&|`]\s*\w+",
+             r"\$\(.*\)",
+
+             # Path traversal
+             r"\.\./|\.\.\\",
+
+             # Sensitive data patterns
+             r"(?i)(password|token|key|secret|auth|credential)",
+             r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b",  # IP addresses
+
+             # Excessive length or repetition
+             r".{10000,}",  # Very long inputs
+             r"(.)\1{50,}",  # Repeated characters
+
+             # Admin/privilege patterns
+             r"(?i)(admin|root|sudo|su -|chmod|chown)"
          ]
+
+         # Rate limiting configuration
+         self.rate_limit_config = {
+             "default": {"requests_per_minute": 60, "burst_capacity": 10},
+             "anonymous": {"requests_per_minute": 30, "burst_capacity": 5},
+             "suspicious": {"requests_per_minute": 10, "burst_capacity": 2}
+         }
+
          self.setup_logging()

      def setup_logging(self):
+         """Setup security logging"""
          self.logger = logging.getLogger(__name__)
+         self.logger.setLevel(logging.INFO)

+     def check_request(self, query: str, user_id: str, ip_address: Optional[str] = None) -> Dict[str, any]:
+         """
+         Comprehensive security check for incoming requests.
+
+         Args:
+             query: User's query text
+             user_id: User identifier
+             ip_address: Optional IP address for IP-based checks
+
+         Returns:
+             Security assessment result
+         """
          result = {
              "is_suspicious": False,
              "alerts": [],
+             "risk_score": 0,
+             "allowed": True,
+             "rate_limit_info": {}
          }

+         # Rate limiting check
+         rate_limit_result = self.check_rate_limit(user_id, ip_address)
+         if not rate_limit_result["allowed"]:
              result["is_suspicious"] = True
+             result["allowed"] = False
              result["alerts"].append("Rate limit exceeded")
              result["risk_score"] = 100
+             result["rate_limit_info"] = rate_limit_result
              return result

+         result["rate_limit_info"] = rate_limit_result
+
+         # Input validation and pattern matching
+         validation_result = self.validate_input(query, user_id)
+         result["alerts"].extend(validation_result["alerts"])
+         result["risk_score"] += validation_result["risk_score"]
+
+         # IP reputation check (if IP provided)
+         if ip_address:
+             ip_result = self.check_ip_reputation(ip_address)
+             result["alerts"].extend(ip_result["alerts"])
+             result["risk_score"] += ip_result["risk_score"]
+
+         # Determine overall suspicion
+         if result["risk_score"] >= 50:
+             result["is_suspicious"] = True
+         if result["risk_score"] >= 80:
+             result["allowed"] = False
+
+         # Log security event
+         self.log_security_event(user_id, ip_address, query, result)
+
+         return result
+
+     def check_rate_limit(self, user_id: str, ip_address: Optional[str] = None) -> Dict[str, any]:
+         """Check rate limits for user and/or IP"""
+         current_time = datetime.now()
+         user_key = f"user_{user_id}"
+         ip_key = f"ip_{ip_address}" if ip_address else None
+
+         # Get rate limit configuration
+         user_config = self.rate_limit_config.get("default")
+         if user_id == "anonymous":
+             user_config = self.rate_limit_config.get("anonymous", user_config)
+
+         # Check if user is marked as suspicious
+         if self.is_suspicious_user(user_id) or (ip_address and self.is_suspicious_ip(ip_address)):
+             user_config = self.rate_limit_config.get("suspicious", user_config)
+
+         # Clean old entries for user
+         self.rate_limits[user_key] = [
+             t for t in self.rate_limits.get(user_key, [])
+             if current_time - t < timedelta(minutes=1)
+         ]
+
+         # Clean old entries for IP (if provided)
+         if ip_key:
+             self.rate_limits[ip_key] = [
+                 t for t in self.rate_limits.get(ip_key, [])
+                 if current_time - t < timedelta(minutes=1)
+             ]
+
+         # Check user rate limit
+         user_requests = len(self.rate_limits.get(user_key, []))
+         user_allowed = user_requests < user_config["requests_per_minute"]
+
+         # Check IP rate limit (if IP provided)
+         ip_allowed = True
+         if ip_key:
+             ip_requests = len(self.rate_limits.get(ip_key, []))
+             ip_allowed = ip_requests < user_config["requests_per_minute"]
+
+         allowed = user_allowed and ip_allowed
+
+         # Add current request to counters if allowed
+         if allowed:
+             self.rate_limits.setdefault(user_key, []).append(current_time)
+             if ip_key:
+                 self.rate_limits.setdefault(ip_key, []).append(current_time)
+
+         return {
+             "allowed": allowed,
+             "user_requests": user_requests,
+             "user_limit": user_config["requests_per_minute"],
+             "ip_requests": len(self.rate_limits.get(ip_key, [])) if ip_key else 0,
+             "ip_limit": user_config["requests_per_minute"] if ip_key else "N/A",
+             "retry_after": 60 if not allowed else 0
+         }
+
+     def validate_input(self, query: str, user_id: str) -> Dict[str, any]:
+         """Validate and analyze user input"""
+         result = {
+             "alerts": [],
+             "risk_score": 0
+         }
+
          # Pattern matching
          for pattern in self.suspicious_patterns:
+             matches = re.findall(pattern, query)
+             if matches:
+                 alert_msg = f"Suspicious pattern detected: {pattern[:50]}..."
+                 result["alerts"].append(alert_msg)
                  result["risk_score"] += 20

          # Query length analysis
+         query_length = len(query)
+         if query_length > 10000:
+             result["alerts"].append("Excessively long query detected")
              result["risk_score"] += 30
+         elif query_length > 5000:
+             result["alerts"].append("Very long query detected")
+             result["risk_score"] += 15

          # Special character analysis
+         special_chars = len(re.findall(r'[^\w\s\.\?\!]', query))
+         special_char_ratio = special_chars / max(len(query), 1)
+
+         if special_char_ratio > 0.3:
              result["alerts"].append("High percentage of special characters")
              result["risk_score"] += 25
+         elif special_char_ratio > 0.2:
+             result["alerts"].append("Elevated special character usage")
+             result["risk_score"] += 10

+         # Entropy analysis (for encrypted/encoded content)
+         entropy = self.calculate_entropy(query)
+         if entropy > 6.0:  # High entropy might indicate encoded/encrypted content
+             result["alerts"].append("High entropy content detected")
+             result["risk_score"] += 20

          return result

+     def check_ip_reputation(self, ip_address: str) -> Dict[str, any]:
+         """Check IP reputation (basic implementation)"""
+         result = {
+             "alerts": [],
+             "risk_score": 0
+         }

+         # Check if IP is in suspicious list
+         if self.is_suspicious_ip(ip_address):
+             result["alerts"].append("IP address has suspicious history")
+             result["risk_score"] += 40
+
+         # Simple IP pattern check (private IPs, localhost, etc.)
+         if ip_address in ["127.0.0.1", "localhost", "0.0.0.0"]:
+             result["alerts"].append("Local IP address detected")
+             result["risk_score"] += 10
+
+         # Check for rapid requests from this IP
+         ip_key = f"ip_{ip_address}"
+         recent_requests = len(self.rate_limits.get(ip_key, []))
+         if recent_requests > 50:  # High volume from single IP
+             result["alerts"].append("High request volume from IP")
+             result["risk_score"] += 15
+
+         return result
+
+     def calculate_entropy(self, text: str) -> float:
+         """Calculate Shannon entropy of text (for detecting encoded content)"""
+         if not text:
+             return 0.0
+
+         import math
+         entropy = 0.0
+         text_length = len(text)
+
+         for char in set(text):
+             p_x = float(text.count(char)) / text_length
+             if p_x > 0:
+                 entropy += - p_x * math.log2(p_x)
+
+         return entropy
+
+     def is_suspicious_user(self, user_id: str) -> bool:
+         """Check if user is marked as suspicious"""
+         # In a real implementation, this would check a database
+         # For now, use simple in-memory tracking
+         user_key = f"user_{user_id}"
+         return self.suspicious_ips.get(user_key, 0) > 5
+
+     def is_suspicious_ip(self, ip_address: str) -> bool:
+         """Check if IP is marked as suspicious"""
+         ip_key = f"ip_{ip_address}"
+         return self.suspicious_ips.get(ip_key, 0) > 3
+
+     def mark_suspicious(self, user_id: str, ip_address: Optional[str] = None, reason: str = ""):
+         """Mark user or IP as suspicious"""
+         if user_id:
+             user_key = f"user_{user_id}"
+             self.suspicious_ips[user_key] = self.suspicious_ips.get(user_key, 0) + 1
+
+         if ip_address:
+             ip_key = f"ip_{ip_address}"
+             self.suspicious_ips[ip_key] = self.suspicious_ips.get(ip_key, 0) + 1
+
+         self.logger.warning(f"Marked as suspicious - User: {user_id}, IP: {ip_address}, Reason: {reason}")
+
+     def log_security_event(self, user_id: str, ip_address: Optional[str], query: str, result: Dict):
+         """Log security event for auditing"""
+         event = {
+             "timestamp": datetime.now().isoformat(),
+             "user_id": user_id,
+             "ip_address": ip_address,
+             "query_preview": query[:100] + "..." if len(query) > 100 else query,
+             "query_length": len(query),
+             "risk_score": result["risk_score"],
+             "alerts": result["alerts"],
+             "allowed": result["allowed"],
+             "is_suspicious": result["is_suspicious"]
+         }
+
+         self.security_log.append(event)
+
+         # Keep only last 1000 events
+         if len(self.security_log) > 1000:
+             self.security_log = self.security_log[-1000:]
+
+         # Log to security logger if high risk
+         if result["risk_score"] >= 50:
+             self.logger.warning(f"Security alert: User {user_id} - Score: {result['risk_score']} - Alerts: {result['alerts']}")
+
+     def get_security_stats(self) -> Dict[str, any]:
+         """Get security statistics"""
+         recent_events = [e for e in self.security_log
+                          if datetime.now() - datetime.fromisoformat(e["timestamp"]) < timedelta(hours=24)]

+         blocked_events = [e for e in recent_events if not e["allowed"]]
+         suspicious_events = [e for e in recent_events if e["is_suspicious"]]

+         return {
+             "total_events_24h": len(recent_events),
+             "blocked_requests_24h": len(blocked_events),
+             "suspicious_requests_24h": len(suspicious_events),
+             "current_suspicious_users": len([k for k, v in self.suspicious_ips.items() if k.startswith("user_") and v > 0]),
+             "current_suspicious_ips": len([k for k, v in self.suspicious_ips.items() if k.startswith("ip_") and v > 0]),
+             "rate_limits_tracked": len(self.rate_limits)
+         }

      def sanitize_input(self, text: str) -> str:
+         """Sanitize user input to prevent injection attacks"""
+         if not text:
+             return ""
+
          # Remove potentially dangerous characters
+         sanitized = re.sub(r'[<>"\']', '', text)
+
+         # Remove SQL injection patterns
+         sanitized = re.sub(r'(\b)(DROP|DELETE|INSERT|UPDATE|ALTER|EXEC)(\b)', '', sanitized, flags=re.IGNORECASE)
+
+         # Remove JavaScript and HTML patterns
+         sanitized = re.sub(r'(javascript|script|onload|onerror|onclick)', '', sanitized, flags=re.IGNORECASE)
+
+         # Remove command injection patterns
+         sanitized = re.sub(r'[;&|`]\s*\w+', '', sanitized)
+
+         return sanitized.strip()
+
+     def reset_rate_limits(self, user_id: Optional[str] = None, ip_address: Optional[str] = None):
+         """Reset rate limits for specific user or IP"""
+         if user_id:
+             user_key = f"user_{user_id}"
+             if user_key in self.rate_limits:
+                 del self.rate_limits[user_key]
+
+         if ip_address:
+             ip_key = f"ip_{ip_address}"
+             if ip_key in self.rate_limits:
+                 del self.rate_limits[ip_key]
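
Reviewer note: a minimal usage sketch of the committed class. This is illustrative only; it assumes the module is importable as src.security_system, and the user id, queries, and IP address are made-up values.

    from src.security_system import AdvancedSecuritySystem

    security = AdvancedSecuritySystem()

    # A benign request: no suspicious patterns, well under the rate limit
    ok = security.check_request("What time does the store open?", user_id="user_42", ip_address="203.0.113.7")
    print(ok["allowed"], ok["risk_score"], ok["alerts"])

    # A request that should trip the SQL, sensitive-data, and admin patterns
    bad = security.check_request("DROP TABLE users; -- admin password", user_id="user_42", ip_address="203.0.113.7")
    print(bad["is_suspicious"], bad["allowed"], bad["alerts"])

    # Sanitization strips quotes, angle brackets, SQL keywords, and shell fragments
    print(security.sanitize_input("<script>alert('hi')</script>; rm -rf /"))

    # Aggregate counters over the last 24 hours of logged events
    print(security.get_security_stats())

With the committed thresholds, the second request should accumulate a risk score of 80 from four pattern hits, so it comes back with is_suspicious True and allowed False.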
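
The entropy heuristic in validate_input flags any query that calculate_entropy scores above 6.0 bits per character. For intuition, a short worked comparison; shannon_entropy below mirrors the committed formula but is a standalone helper written for this note, not part of the commit.

    import math
    import secrets
    import string

    def shannon_entropy(text: str) -> float:
        # Same formula as AdvancedSecuritySystem.calculate_entropy:
        # H = -sum(p(x) * log2(p(x))) over the distinct characters in text
        if not text:
            return 0.0
        length = len(text)
        return -sum(
            (text.count(c) / length) * math.log2(text.count(c) / length)
            for c in set(text)
        )

    prose = "rate limiting and input validation protect the service from abuse"
    encoded = "".join(secrets.choice(string.ascii_letters + string.digits + "+/") for _ in range(2000))

    print(round(shannon_entropy(prose), 2))    # English prose sits around 4 bits per character
    print(round(shannon_entropy(encoded), 2))  # a base64-style blob approaches log2(64) = 6

Because a 64-symbol alphabet tops out at 6 bits per character, the strict greater-than-6.0 check mainly targets inputs drawn from a wider character range (raw binary or dense Unicode pasted as text); long base64 strings are more likely to be caught by the length heuristics.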