DeWitt Gibson commited on
Commit
4a3bb46
·
1 Parent(s): 244340f

Adding Defender code

Browse files
src/llmguardian/defenders/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ defenders/__init__.py - Security defenders initialization
3
+ """
4
+
5
+ from .input_sanitizer import InputSanitizer
6
+ from .output_validator import OutputValidator
7
+ from .token_validator import TokenValidator
8
+ from .content_filter import ContentFilter
9
+ from .context_validator import ContextValidator
10
+
11
+ __all__ = [
12
+ 'InputSanitizer',
13
+ 'OutputValidator',
14
+ 'TokenValidator',
15
+ 'ContentFilter',
16
+ 'ContextValidator',
17
+ ]
src/llmguardian/defenders/content_filter.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ defenders/content_filter.py - Content filtering and moderation
3
+ """
4
+
5
+ import re
6
+ from typing import Dict, List, Optional, Any, Set
7
+ from dataclasses import dataclass
8
+ from enum import Enum
9
+ from ..core.logger import SecurityLogger
10
+ from ..core.exceptions import ValidationError
11
+
12
+ class ContentCategory(Enum):
13
+ MALICIOUS = "malicious"
14
+ SENSITIVE = "sensitive"
15
+ HARMFUL = "harmful"
16
+ INAPPROPRIATE = "inappropriate"
17
+ POTENTIAL_EXPLOIT = "potential_exploit"
18
+
19
+ @dataclass
20
+ class FilterRule:
21
+ pattern: str
22
+ category: ContentCategory
23
+ severity: int # 1-10
24
+ description: str
25
+ action: str # "block" or "sanitize"
26
+ replacement: str = "[FILTERED]"
27
+
28
+ @dataclass
29
+ class FilterResult:
30
+ is_allowed: bool
31
+ filtered_content: str
32
+ matched_rules: List[str]
33
+ risk_score: int
34
+ categories: Set[ContentCategory]
35
+ details: Dict[str, Any]
36
+
37
+ class ContentFilter:
38
+ def __init__(self, security_logger: Optional[SecurityLogger] = None):
39
+ self.security_logger = security_logger
40
+ self.rules = self._initialize_rules()
41
+ self.compiled_rules = {
42
+ name: re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
43
+ for name, rule in self.rules.items()
44
+ }
45
+
46
+ def _initialize_rules(self) -> Dict[str, FilterRule]:
47
+ return {
48
+ "code_execution": FilterRule(
49
+ pattern=r"(?:exec|eval|system|subprocess|os\.)",
50
+ category=ContentCategory.MALICIOUS,
51
+ severity=9,
52
+ description="Code execution attempt",
53
+ action="block"
54
+ ),
55
+ "sql_commands": FilterRule(
56
+ pattern=r"(?:SELECT|INSERT|UPDATE|DELETE|DROP|UNION)\s+(?:FROM|INTO|TABLE)",
57
+ category=ContentCategory.MALICIOUS,
58
+ severity=8,
59
+ description="SQL command",
60
+ action="block"
61
+ ),
62
+ "file_operations": FilterRule(
63
+ pattern=r"(?:read|write|open|delete|remove)\s*\(['\"].*?['\"]",
64
+ category=ContentCategory.POTENTIAL_EXPLOIT,
65
+ severity=7,
66
+ description="File operation",
67
+ action="block"
68
+ ),
69
+ "pii_data": FilterRule(
70
+ pattern=r"\b\d{3}-\d{2}-\d{4}\b|\b\d{16}\b",
71
+ category=ContentCategory.SENSITIVE,
72
+ severity=8,
73
+ description="PII data",
74
+ action="sanitize",
75
+ replacement="[REDACTED]"
76
+ ),
77
+ "harmful_content": FilterRule(
78
+ pattern=r"(?:hack|exploit|bypass|vulnerability)\s+(?:system|security|protection)",
79
+ category=ContentCategory.HARMFUL,
80
+ severity=7,
81
+ description="Potentially harmful content",
82
+ action="block"
83
+ ),
84
+ "inappropriate_content": FilterRule(
85
+ pattern=r"(?:explicit|offensive|inappropriate).*content",
86
+ category=ContentCategory.INAPPROPRIATE,
87
+ severity=6,
88
+ description="Inappropriate content",
89
+ action="sanitize"
90
+ ),
91
+ }
92
+
93
+ def filter_content(self, content: str, context: Optional[Dict[str, Any]] = None) -> FilterResult:
94
+ try:
95
+ matched_rules = []
96
+ categories = set()
97
+ risk_score = 0
98
+ filtered = content
99
+ is_allowed = True
100
+
101
+ for name, rule in self.rules.items():
102
+ pattern = self.compiled_rules[name]
103
+ matches = pattern.findall(filtered)
104
+
105
+ if matches:
106
+ matched_rules.append(name)
107
+ categories.add(rule.category)
108
+ risk_score = max(risk_score, rule.severity)
109
+
110
+ if rule.action == "block":
111
+ is_allowed = False
112
+ elif rule.action == "sanitize":
113
+ filtered = pattern.sub(rule.replacement, filtered)
114
+
115
+ result = FilterResult(
116
+ is_allowed=is_allowed,
117
+ filtered_content=filtered if is_allowed else "[CONTENT BLOCKED]",
118
+ matched_rules=matched_rules,
119
+ risk_score=risk_score,
120
+ categories=categories,
121
+ details={
122
+ "original_length": len(content),
123
+ "filtered_length": len(filtered),
124
+ "rule_matches": len(matched_rules),
125
+ "context": context or {}
126
+ }
127
+ )
128
+
129
+ if matched_rules and self.security_logger:
130
+ self.security_logger.log_security_event(
131
+ "content_filtered",
132
+ matched_rules=matched_rules,
133
+ categories=[c.value for c in categories],
134
+ risk_score=risk_score,
135
+ is_allowed=is_allowed
136
+ )
137
+
138
+ return result
139
+
140
+ except Exception as e:
141
+ if self.security_logger:
142
+ self.security_logger.log_security_event(
143
+ "filter_error",
144
+ error=str(e),
145
+ content_length=len(content)
146
+ )
147
+ raise ValidationError(f"Content filtering failed: {str(e)}")
148
+
149
+ def add_rule(self, name: str, rule: FilterRule) -> None:
150
+ self.rules[name] = rule
151
+ self.compiled_rules[name] = re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
152
+
153
+ def remove_rule(self, name: str) -> None:
154
+ self.rules.pop(name, None)
155
+ self.compiled_rules.pop(name, None)
156
+
157
+ def get_rules(self) -> Dict[str, Dict[str, Any]]:
158
+ return {
159
+ name: {
160
+ "pattern": rule.pattern,
161
+ "category": rule.category.value,
162
+ "severity": rule.severity,
163
+ "description": rule.description,
164
+ "action": rule.action
165
+ }
166
+ for name, rule in self.rules.items()
167
+ }
src/llmguardian/defenders/context_validator.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ defenders/context_validator.py - Context validation for LLM interactions
3
+ """
4
+
5
+ from typing import Dict, Optional, List, Any
6
+ from dataclasses import dataclass
7
+ from datetime import datetime
8
+ import hashlib
9
+ from ..core.logger import SecurityLogger
10
+ from ..core.exceptions import ValidationError
11
+
12
+ @dataclass
13
+ class ContextRule:
14
+ max_age: int # seconds
15
+ required_fields: List[str]
16
+ forbidden_fields: List[str]
17
+ max_depth: int
18
+ checksum_fields: List[str]
19
+
20
+ @dataclass
21
+ class ValidationResult:
22
+ is_valid: bool
23
+ errors: List[str]
24
+ modified_context: Dict[str, Any]
25
+ metadata: Dict[str, Any]
26
+
27
+ class ContextValidator:
28
+ def __init__(self, security_logger: Optional[SecurityLogger] = None):
29
+ self.security_logger = security_logger
30
+ self.rule = ContextRule(
31
+ max_age=3600,
32
+ required_fields=["user_id", "session_id", "timestamp"],
33
+ forbidden_fields=["password", "secret", "token"],
34
+ max_depth=5,
35
+ checksum_fields=["user_id", "session_id"]
36
+ )
37
+
38
+ def validate_context(self, context: Dict[str, Any], previous_context: Optional[Dict[str, Any]] = None) -> ValidationResult:
39
+ try:
40
+ errors = []
41
+ modified = context.copy()
42
+
43
+ # Check required fields
44
+ missing = [f for f in self.rule.required_fields if f not in context]
45
+ if missing:
46
+ errors.append(f"Missing required fields: {missing}")
47
+
48
+ # Check forbidden fields
49
+ forbidden = [f for f in self.rule.forbidden_fields if f in context]
50
+ if forbidden:
51
+ errors.append(f"Forbidden fields present: {forbidden}")
52
+ for field in forbidden:
53
+ modified.pop(field, None)
54
+
55
+ # Validate timestamp
56
+ if "timestamp" in context:
57
+ age = (datetime.utcnow() - datetime.fromisoformat(str(context["timestamp"]))).seconds
58
+ if age > self.rule.max_age:
59
+ errors.append(f"Context too old: {age} seconds")
60
+
61
+ # Check context depth
62
+ if not self._check_depth(context, 0):
63
+ errors.append(f"Context exceeds max depth of {self.rule.max_depth}")
64
+
65
+ # Verify checksums if previous context exists
66
+ if previous_context:
67
+ if not self._verify_checksums(context, previous_context):
68
+ errors.append("Context checksum mismatch")
69
+
70
+ # Build metadata
71
+ metadata = {
72
+ "validation_time": datetime.utcnow().isoformat(),
73
+ "original_size": len(str(context)),
74
+ "modified_size": len(str(modified)),
75
+ "changes": len(errors)
76
+ }
77
+
78
+ result = ValidationResult(
79
+ is_valid=len(errors) == 0,
80
+ errors=errors,
81
+ modified_context=modified,
82
+ metadata=metadata
83
+ )
84
+
85
+ if errors and self.security_logger:
86
+ self.security_logger.log_security_event(
87
+ "context_validation_failure",
88
+ errors=errors,
89
+ context_id=context.get("context_id")
90
+ )
91
+
92
+ return result
93
+
94
+ except Exception as e:
95
+ if self.security_logger:
96
+ self.security_logger.log_security_event(
97
+ "context_validation_error",
98
+ error=str(e)
99
+ )
100
+ raise ValidationError(f"Context validation failed: {str(e)}")
101
+
102
+ def _check_depth(self, obj: Any, depth: int) -> bool:
103
+ if depth > self.rule.max_depth:
104
+ return False
105
+ if isinstance(obj, dict):
106
+ return all(self._check_depth(v, depth + 1) for v in obj.values())
107
+ if isinstance(obj, list):
108
+ return all(self._check_depth(v, depth + 1) for v in obj)
109
+ return True
110
+
111
+ def _verify_checksums(self, current: Dict[str, Any], previous: Dict[str, Any]) -> bool:
112
+ for field in self.rule.checksum_fields:
113
+ if field in current and field in previous:
114
+ current_hash = hashlib.sha256(str(current[field]).encode()).hexdigest()
115
+ previous_hash = hashlib.sha256(str(previous[field]).encode()).hexdigest()
116
+ if current_hash != previous_hash:
117
+ return False
118
+ return True
119
+
120
+ def update_rule(self, updates: Dict[str, Any]) -> None:
121
+ for key, value in updates.items():
122
+ if hasattr(self.rule, key):
123
+ setattr(self.rule, key, value)
src/llmguardian/defenders/input_sanitizer.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ defenders/input_sanitizer.py - Input sanitization for LLM inputs
3
+ """
4
+
5
+ import re
6
+ from typing import Dict, Any, List, Optional
7
+ from dataclasses import dataclass
8
+ from ..core.logger import SecurityLogger
9
+ from ..core.exceptions import ValidationError
10
+
11
+ @dataclass
12
+ class SanitizationRule:
13
+ pattern: str
14
+ replacement: str
15
+ description: str
16
+ enabled: bool = True
17
+
18
+ @dataclass
19
+ class SanitizationResult:
20
+ original: str
21
+ sanitized: str
22
+ applied_rules: List[str]
23
+ is_modified: bool
24
+ risk_level: str
25
+
26
+ class InputSanitizer:
27
+ def __init__(self, security_logger: Optional[SecurityLogger] = None):
28
+ self.security_logger = security_logger
29
+ self.rules = self._initialize_rules()
30
+ self.compiled_rules = {
31
+ name: re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
32
+ for name, rule in self.rules.items()
33
+ if rule.enabled
34
+ }
35
+
36
+ def _initialize_rules(self) -> Dict[str, SanitizationRule]:
37
+ return {
38
+ "system_instructions": SanitizationRule(
39
+ pattern=r"system:\s*|instruction:\s*",
40
+ replacement=" ",
41
+ description="Remove system instruction markers"
42
+ ),
43
+ "code_injection": SanitizationRule(
44
+ pattern=r"<script.*?>.*?</script>",
45
+ replacement="",
46
+ description="Remove script tags"
47
+ ),
48
+ "delimiter_injection": SanitizationRule(
49
+ pattern=r"[<\[{](?:system|prompt|instruction)[>\]}]",
50
+ replacement="",
51
+ description="Remove delimiter-based injections"
52
+ ),
53
+ "command_injection": SanitizationRule(
54
+ pattern=r"(?:exec|eval|system)\s*\(",
55
+ replacement="",
56
+ description="Remove command execution attempts"
57
+ ),
58
+ "encoding_patterns": SanitizationRule(
59
+ pattern=r"(?:base64|hex|rot13)\s*\(",
60
+ replacement="",
61
+ description="Remove encoding attempts"
62
+ ),
63
+ }
64
+
65
+ def sanitize(self, input_text: str, context: Optional[Dict[str, Any]] = None) -> SanitizationResult:
66
+ original = input_text
67
+ applied_rules = []
68
+ is_modified = False
69
+
70
+ try:
71
+ sanitized = input_text
72
+ for name, rule in self.rules.items():
73
+ if not rule.enabled:
74
+ continue
75
+
76
+ pattern = self.compiled_rules.get(name)
77
+ if not pattern:
78
+ continue
79
+
80
+ new_text = pattern.sub(rule.replacement, sanitized)
81
+ if new_text != sanitized:
82
+ applied_rules.append(name)
83
+ is_modified = True
84
+ sanitized = new_text
85
+
86
+ risk_level = self._assess_risk(applied_rules)
87
+
88
+ if is_modified and self.security_logger:
89
+ self.security_logger.log_security_event(
90
+ "input_sanitization",
91
+ original_length=len(original),
92
+ sanitized_length=len(sanitized),
93
+ applied_rules=applied_rules,
94
+ risk_level=risk_level
95
+ )
96
+
97
+ return SanitizationResult(
98
+ original=original,
99
+ sanitized=sanitized,
100
+ applied_rules=applied_rules,
101
+ is_modified=is_modified,
102
+ risk_level=risk_level
103
+ )
104
+
105
+ except Exception as e:
106
+ if self.security_logger:
107
+ self.security_logger.log_security_event(
108
+ "sanitization_error",
109
+ error=str(e),
110
+ input_length=len(input_text)
111
+ )
112
+ raise ValidationError(f"Sanitization failed: {str(e)}")
113
+
114
+ def _assess_risk(self, applied_rules: List[str]) -> str:
115
+ if not applied_rules:
116
+ return "low"
117
+ if len(applied_rules) >= 3:
118
+ return "high"
119
+ if "command_injection" in applied_rules or "code_injection" in applied_rules:
120
+ return "high"
121
+ return "medium"
122
+
123
+ def add_rule(self, name: str, rule: SanitizationRule) -> None:
124
+ self.rules[name] = rule
125
+ if rule.enabled:
126
+ self.compiled_rules[name] = re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
127
+
128
+ def remove_rule(self, name: str) -> None:
129
+ self.rules.pop(name, None)
130
+ self.compiled_rules.pop(name, None)
131
+
132
+ def get_rules(self) -> Dict[str, Dict[str, Any]]:
133
+ return {
134
+ name: {
135
+ "pattern": rule.pattern,
136
+ "replacement": rule.replacement,
137
+ "description": rule.description,
138
+ "enabled": rule.enabled
139
+ }
140
+ for name, rule in self.rules.items()
141
+ }
src/llmguardian/defenders/output_validator.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ defenders/output_validator.py - Output validation and sanitization
3
+ """
4
+
5
+ import re
6
+ from typing import Dict, List, Optional, Set, Any
7
+ from dataclasses import dataclass
8
+ from ..core.logger import SecurityLogger
9
+ from ..core.exceptions import ValidationError
10
+
11
+ @dataclass
12
+ class ValidationRule:
13
+ pattern: str
14
+ description: str
15
+ severity: int # 1-10
16
+ block: bool = True
17
+ sanitize: bool = True
18
+ replacement: str = ""
19
+
20
+ @dataclass
21
+ class ValidationResult:
22
+ is_valid: bool
23
+ violations: List[str]
24
+ sanitized_output: Optional[str]
25
+ risk_score: int
26
+ details: Dict[str, Any]
27
+
28
+ class OutputValidator:
29
+ def __init__(self, security_logger: Optional[SecurityLogger] = None):
30
+ self.security_logger = security_logger
31
+ self.rules = self._initialize_rules()
32
+ self.compiled_rules = {
33
+ name: re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
34
+ for name, rule in self.rules.items()
35
+ }
36
+ self.sensitive_patterns = self._initialize_sensitive_patterns()
37
+
38
+ def _initialize_rules(self) -> Dict[str, ValidationRule]:
39
+ return {
40
+ "sql_injection": ValidationRule(
41
+ pattern=r"(?:SELECT|INSERT|UPDATE|DELETE)\s+(?:FROM|INTO)\s+\w+",
42
+ description="SQL query in output",
43
+ severity=9,
44
+ block=True
45
+ ),
46
+ "code_injection": ValidationRule(
47
+ pattern=r"<script.*?>.*?</script>",
48
+ description="JavaScript code in output",
49
+ severity=8,
50
+ block=True
51
+ ),
52
+ "system_info": ValidationRule(
53
+ pattern=r"(?:system|config|env|secret)(?:_|\s+)?(?:key|token|password)",
54
+ description="System information leak",
55
+ severity=9,
56
+ block=True
57
+ ),
58
+ "personal_data": ValidationRule(
59
+ pattern=r"\b\d{3}-\d{2}-\d{4}\b|\b\d{16}\b",
60
+ description="Personal data (SSN/CC)",
61
+ severity=10,
62
+ block=True
63
+ ),
64
+ "file_paths": ValidationRule(
65
+ pattern=r"(?:/[\w./]+)|(?:C:\\[\w\\]+)",
66
+ description="File system paths",
67
+ severity=7,
68
+ block=True
69
+ ),
70
+ "html_content": ValidationRule(
71
+ pattern=r"<(?!br|p|b|i|em|strong)[^>]+>",
72
+ description="HTML content",
73
+ severity=6,
74
+ sanitize=True,
75
+ replacement=""
76
+ ),
77
+ }
78
+
79
+ def _initialize_sensitive_patterns(self) -> Set[str]:
80
+ return {
81
+ r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", # Email
82
+ r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", # IP address
83
+ r"(?i)api[_-]?key", # API keys
84
+ r"(?i)password|passwd|pwd", # Passwords
85
+ r"(?i)token|secret|credential", # Credentials
86
+ r"\b[A-Z0-9]{20,}\b", # Long alphanumeric strings
87
+ }
88
+
89
+ def validate(self, output: str, context: Optional[Dict[str, Any]] = None) -> ValidationResult:
90
+ try:
91
+ violations = []
92
+ risk_score = 0
93
+ sanitized = output
94
+ is_valid = True
95
+
96
+ # Check against validation rules
97
+ for name, rule in self.rules.items():
98
+ pattern = self.compiled_rules[name]
99
+ matches = pattern.findall(sanitized)
100
+
101
+ if matches:
102
+ violations.append(f"{name}: {rule.description}")
103
+ risk_score = max(risk_score, rule.severity)
104
+
105
+ if rule.block:
106
+ is_valid = False
107
+
108
+ if rule.sanitize:
109
+ sanitized = pattern.sub(rule.replacement, sanitized)
110
+
111
+ # Check for sensitive data patterns
112
+ for pattern in self.sensitive_patterns:
113
+ matches = re.findall(pattern, sanitized)
114
+ if matches:
115
+ violations.append(f"Sensitive data detected: {pattern}")
116
+ risk_score = max(risk_score, 8)
117
+ is_valid = False
118
+ sanitized = re.sub(pattern, "[REDACTED]", sanitized)
119
+
120
+ result = ValidationResult(
121
+ is_valid=is_valid,
122
+ violations=violations,
123
+ sanitized_output=sanitized if violations else output,
124
+ risk_score=risk_score,
125
+ details={
126
+ "original_length": len(output),
127
+ "sanitized_length": len(sanitized),
128
+ "violation_count": len(violations),
129
+ "context": context or {}
130
+ }
131
+ )
132
+
133
+ if violations and self.security_logger:
134
+ self.security_logger.log_security_event(
135
+ "output_validation",
136
+ violations=violations,
137
+ risk_score=risk_score,
138
+ is_valid=is_valid
139
+ )
140
+
141
+ return result
142
+
143
+ except Exception as e:
144
+ if self.security_logger:
145
+ self.security_logger.log_security_event(
146
+ "validation_error",
147
+ error=str(e),
148
+ output_length=len(output)
149
+ )
150
+ raise ValidationError(f"Output validation failed: {str(e)}")
151
+
152
+ def add_rule(self, name: str, rule: ValidationRule) -> None:
153
+ self.rules[name] = rule
154
+ self.compiled_rules[name] = re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
155
+
156
+ def remove_rule(self, name: str) -> None:
157
+ self.rules.pop(name, None)
158
+ self.compiled_rules.pop(name, None)
159
+
160
+ def add_sensitive_pattern(self, pattern: str) -> None:
161
+ self.sensitive_patterns.add(pattern)
162
+
163
+ def get_rules(self) -> Dict[str, Dict[str, Any]]:
164
+ return {
165
+ name: {
166
+ "pattern": rule.pattern,
167
+ "description": rule.description,
168
+ "severity": rule.severity,
169
+ "block": rule.block,
170
+ "sanitize": rule.sanitize
171
+ }
172
+ for name, rule in self.rules.items()
173
+ }
src/llmguardian/defenders/test_context_validator.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ tests/defenders/test_context_validator.py - Tests for context validation
3
+ """
4
+
5
+ import pytest
6
+ from datetime import datetime, timedelta
7
+ from llmguardian.defenders.context_validator import ContextValidator, ValidationResult
8
+ from llmguardian.core.exceptions import ValidationError
9
+
10
+ @pytest.fixture
11
+ def validator():
12
+ return ContextValidator()
13
+
14
+ @pytest.fixture
15
+ def valid_context():
16
+ return {
17
+ "user_id": "test_user",
18
+ "session_id": "test_session",
19
+ "timestamp": datetime.utcnow().isoformat(),
20
+ "request_id": "123",
21
+ "metadata": {
22
+ "source": "test",
23
+ "version": "1.0"
24
+ }
25
+ }
26
+
27
+ def test_valid_context(validator, valid_context):
28
+ result = validator.validate_context(valid_context)
29
+ assert result.is_valid
30
+ assert not result.errors
31
+ assert result.modified_context == valid_context
32
+
33
+ def test_missing_required_fields(validator):
34
+ context = {
35
+ "user_id": "test_user",
36
+ "timestamp": datetime.utcnow().isoformat()
37
+ }
38
+ result = validator.validate_context(context)
39
+ assert not result.is_valid
40
+ assert "Missing required fields" in result.errors[0]
41
+
42
+ def test_forbidden_fields(validator, valid_context):
43
+ context = valid_context.copy()
44
+ context["password"] = "secret123"
45
+ result = validator.validate_context(context)
46
+ assert not result.is_valid
47
+ assert "Forbidden fields present" in result.errors[0]
48
+ assert "password" not in result.modified_context
49
+
50
+ def test_context_age(validator, valid_context):
51
+ old_context = valid_context.copy()
52
+ old_context["timestamp"] = (
53
+ datetime.utcnow() - timedelta(hours=2)
54
+ ).isoformat()
55
+ result = validator.validate_context(old_context)
56
+ assert not result.is_valid
57
+ assert "Context too old" in result.errors[0]
58
+
59
+ def test_context_depth(validator, valid_context):
60
+ deep_context = valid_context.copy()
61
+ current = deep_context
62
+ for i in range(10):
63
+ current["nested"] = {}
64
+ current = current["nested"]
65
+ result = validator.validate_context(deep_context)
66
+ assert not result.is_valid
67
+ assert "Context exceeds max depth" in result.errors[0]
68
+
69
+ def test_checksum_verification(validator, valid_context):
70
+ previous_context = valid_context.copy()
71
+ modified_context = valid_context.copy()
72
+ modified_context["user_id"] = "different_user"
73
+ result = validator.validate_context(modified_context, previous_context)
74
+ assert not result.is_valid
75
+ assert "Context checksum mismatch" in result.errors[0]
76
+
77
+ def test_update_rule(validator):
78
+ validator.update_rule({"max_age": 7200})
79
+ old_context = {
80
+ "user_id": "test_user",
81
+ "session_id": "test_session",
82
+ "timestamp": (
83
+ datetime.utcnow() - timedelta(hours=1.5)
84
+ ).isoformat()
85
+ }
86
+ result = validator.validate_context(old_context)
87
+ assert result.is_valid
88
+
89
+ def test_exception_handling(validator):
90
+ with pytest.raises(ValidationError):
91
+ validator.validate_context({"timestamp": "invalid_date"})
92
+
93
+ def test_metadata_generation(validator, valid_context):
94
+ result = validator.validate_context(valid_context)
95
+ assert "validation_time" in result.metadata
96
+ assert "original_size" in result.metadata
97
+ assert "modified_size" in result.metadata
98
+ assert "changes" in result.metadata
src/llmguardian/defenders/token_validator.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ defenders/token_validator.py - Token and credential validation
3
+ """
4
+
5
+ from typing import Dict, Optional, Any, List
6
+ from dataclasses import dataclass
7
+ import re
8
+ import jwt
9
+ from datetime import datetime, timedelta
10
+ from ..core.logger import SecurityLogger
11
+ from ..core.exceptions import TokenValidationError
12
+
13
+ @dataclass
14
+ class TokenRule:
15
+ pattern: str
16
+ description: str
17
+ min_length: int
18
+ max_length: int
19
+ required_chars: str
20
+ expiry_time: int # in seconds
21
+
22
+ @dataclass
23
+ class TokenValidationResult:
24
+ is_valid: bool
25
+ errors: List[str]
26
+ metadata: Dict[str, Any]
27
+ expiry: Optional[datetime]
28
+
29
+ class TokenValidator:
30
+ def __init__(self, security_logger: Optional[SecurityLogger] = None):
31
+ self.security_logger = security_logger
32
+ self.rules = self._initialize_rules()
33
+ self.secret_key = self._load_secret_key()
34
+
35
+ def _initialize_rules(self) -> Dict[str, TokenRule]:
36
+ return {
37
+ "jwt": TokenRule(
38
+ pattern=r"^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.[A-Za-z0-9-_.+/=]+$",
39
+ description="JWT token",
40
+ min_length=32,
41
+ max_length=4096,
42
+ required_chars=".-_",
43
+ expiry_time=3600
44
+ ),
45
+ "api_key": TokenRule(
46
+ pattern=r"^[A-Za-z0-9]{32,64}$",
47
+ description="API key",
48
+ min_length=32,
49
+ max_length=64,
50
+ required_chars="",
51
+ expiry_time=86400
52
+ ),
53
+ "session_token": TokenRule(
54
+ pattern=r"^[A-Fa-f0-9]{64}$",
55
+ description="Session token",
56
+ min_length=64,
57
+ max_length=64,
58
+ required_chars="",
59
+ expiry_time=7200
60
+ )
61
+ }
62
+
63
+ def _load_secret_key(self) -> bytes:
64
+ # Implementation would load from secure storage
65
+ return b"your-256-bit-secret"
66
+
67
+ def validate_token(self, token: str, token_type: str) -> TokenValidationResult:
68
+ try:
69
+ if token_type not in self.rules:
70
+ raise TokenValidationError(f"Unknown token type: {token_type}")
71
+
72
+ rule = self.rules[token_type]
73
+ errors = []
74
+ metadata = {}
75
+
76
+ # Length validation
77
+ if len(token) < rule.min_length or len(token) > rule.max_length:
78
+ errors.append(f"Token length must be between {rule.min_length} and {rule.max_length}")
79
+
80
+ # Pattern validation
81
+ if not re.match(rule.pattern, token):
82
+ errors.append("Token format is invalid")
83
+
84
+ # Required characters
85
+ if rule.required_chars:
86
+ missing_chars = set(rule.required_chars) - set(token)
87
+ if missing_chars:
88
+ errors.append(f"Token missing required characters: {missing_chars}")
89
+
90
+ # JWT-specific validation
91
+ if token_type == "jwt":
92
+ try:
93
+ payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
94
+ metadata = payload
95
+ exp = datetime.fromtimestamp(payload.get("exp", 0))
96
+ if exp < datetime.utcnow():
97
+ errors.append("Token has expired")
98
+ except jwt.InvalidTokenError as e:
99
+ errors.append(f"Invalid JWT: {str(e)}")
100
+
101
+ is_valid = len(errors) == 0
102
+ expiry = datetime.utcnow() + timedelta(seconds=rule.expiry_time)
103
+
104
+ if not is_valid and self.security_logger:
105
+ self.security_logger.log_security_event(
106
+ "token_validation_failure",
107
+ token_type=token_type,
108
+ errors=errors
109
+ )
110
+
111
+ return TokenValidationResult(
112
+ is_valid=is_valid,
113
+ errors=errors,
114
+ metadata=metadata,
115
+ expiry=expiry if is_valid else None
116
+ )
117
+
118
+ except Exception as e:
119
+ if self.security_logger:
120
+ self.security_logger.log_security_event(
121
+ "token_validation_error",
122
+ error=str(e)
123
+ )
124
+ raise TokenValidationError(f"Validation failed: {str(e)}")
125
+
126
+ def create_token(self, token_type: str, payload: Dict[str, Any]) -> str:
127
+ if token_type not in self.rules:
128
+ raise TokenValidationError(f"Unknown token type: {token_type}")
129
+
130
+ try:
131
+ if token_type == "jwt":
132
+ expiry = datetime.utcnow() + timedelta(
133
+ seconds=self.rules[token_type].expiry_time
134
+ )
135
+ payload["exp"] = expiry.timestamp()
136
+ return jwt.encode(payload, self.secret_key, algorithm="HS256")
137
+
138
+ # Add other token type creation logic here
139
+ raise TokenValidationError(f"Token creation not implemented for {token_type}")
140
+
141
+ except Exception as e:
142
+ if self.security_logger:
143
+ self.security_logger.log_security_event(
144
+ "token_creation_error",
145
+ error=str(e)
146
+ )
147
+ raise TokenValidationError(f"Token creation failed: {str(e)}")