Spaces:
Configuration error
Configuration error
DeWitt Gibson
commited on
Commit
·
4a3bb46
1
Parent(s):
244340f
Adding Defender code
Browse files- src/llmguardian/defenders/__init__.py +17 -0
- src/llmguardian/defenders/content_filter.py +167 -0
- src/llmguardian/defenders/context_validator.py +123 -0
- src/llmguardian/defenders/input_sanitizer.py +141 -0
- src/llmguardian/defenders/output_validator.py +173 -0
- src/llmguardian/defenders/test_context_validator.py +98 -0
- src/llmguardian/defenders/token_validator.py +147 -0
src/llmguardian/defenders/__init__.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
defenders/__init__.py - Security defenders initialization
|
3 |
+
"""
|
4 |
+
|
5 |
+
from .input_sanitizer import InputSanitizer
|
6 |
+
from .output_validator import OutputValidator
|
7 |
+
from .token_validator import TokenValidator
|
8 |
+
from .content_filter import ContentFilter
|
9 |
+
from .context_validator import ContextValidator
|
10 |
+
|
11 |
+
__all__ = [
|
12 |
+
'InputSanitizer',
|
13 |
+
'OutputValidator',
|
14 |
+
'TokenValidator',
|
15 |
+
'ContentFilter',
|
16 |
+
'ContextValidator',
|
17 |
+
]
|
src/llmguardian/defenders/content_filter.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
defenders/content_filter.py - Content filtering and moderation
|
3 |
+
"""
|
4 |
+
|
5 |
+
import re
|
6 |
+
from typing import Dict, List, Optional, Any, Set
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from enum import Enum
|
9 |
+
from ..core.logger import SecurityLogger
|
10 |
+
from ..core.exceptions import ValidationError
|
11 |
+
|
12 |
+
class ContentCategory(Enum):
|
13 |
+
MALICIOUS = "malicious"
|
14 |
+
SENSITIVE = "sensitive"
|
15 |
+
HARMFUL = "harmful"
|
16 |
+
INAPPROPRIATE = "inappropriate"
|
17 |
+
POTENTIAL_EXPLOIT = "potential_exploit"
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class FilterRule:
|
21 |
+
pattern: str
|
22 |
+
category: ContentCategory
|
23 |
+
severity: int # 1-10
|
24 |
+
description: str
|
25 |
+
action: str # "block" or "sanitize"
|
26 |
+
replacement: str = "[FILTERED]"
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class FilterResult:
|
30 |
+
is_allowed: bool
|
31 |
+
filtered_content: str
|
32 |
+
matched_rules: List[str]
|
33 |
+
risk_score: int
|
34 |
+
categories: Set[ContentCategory]
|
35 |
+
details: Dict[str, Any]
|
36 |
+
|
37 |
+
class ContentFilter:
|
38 |
+
def __init__(self, security_logger: Optional[SecurityLogger] = None):
|
39 |
+
self.security_logger = security_logger
|
40 |
+
self.rules = self._initialize_rules()
|
41 |
+
self.compiled_rules = {
|
42 |
+
name: re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
|
43 |
+
for name, rule in self.rules.items()
|
44 |
+
}
|
45 |
+
|
46 |
+
def _initialize_rules(self) -> Dict[str, FilterRule]:
|
47 |
+
return {
|
48 |
+
"code_execution": FilterRule(
|
49 |
+
pattern=r"(?:exec|eval|system|subprocess|os\.)",
|
50 |
+
category=ContentCategory.MALICIOUS,
|
51 |
+
severity=9,
|
52 |
+
description="Code execution attempt",
|
53 |
+
action="block"
|
54 |
+
),
|
55 |
+
"sql_commands": FilterRule(
|
56 |
+
pattern=r"(?:SELECT|INSERT|UPDATE|DELETE|DROP|UNION)\s+(?:FROM|INTO|TABLE)",
|
57 |
+
category=ContentCategory.MALICIOUS,
|
58 |
+
severity=8,
|
59 |
+
description="SQL command",
|
60 |
+
action="block"
|
61 |
+
),
|
62 |
+
"file_operations": FilterRule(
|
63 |
+
pattern=r"(?:read|write|open|delete|remove)\s*\(['\"].*?['\"]",
|
64 |
+
category=ContentCategory.POTENTIAL_EXPLOIT,
|
65 |
+
severity=7,
|
66 |
+
description="File operation",
|
67 |
+
action="block"
|
68 |
+
),
|
69 |
+
"pii_data": FilterRule(
|
70 |
+
pattern=r"\b\d{3}-\d{2}-\d{4}\b|\b\d{16}\b",
|
71 |
+
category=ContentCategory.SENSITIVE,
|
72 |
+
severity=8,
|
73 |
+
description="PII data",
|
74 |
+
action="sanitize",
|
75 |
+
replacement="[REDACTED]"
|
76 |
+
),
|
77 |
+
"harmful_content": FilterRule(
|
78 |
+
pattern=r"(?:hack|exploit|bypass|vulnerability)\s+(?:system|security|protection)",
|
79 |
+
category=ContentCategory.HARMFUL,
|
80 |
+
severity=7,
|
81 |
+
description="Potentially harmful content",
|
82 |
+
action="block"
|
83 |
+
),
|
84 |
+
"inappropriate_content": FilterRule(
|
85 |
+
pattern=r"(?:explicit|offensive|inappropriate).*content",
|
86 |
+
category=ContentCategory.INAPPROPRIATE,
|
87 |
+
severity=6,
|
88 |
+
description="Inappropriate content",
|
89 |
+
action="sanitize"
|
90 |
+
),
|
91 |
+
}
|
92 |
+
|
93 |
+
def filter_content(self, content: str, context: Optional[Dict[str, Any]] = None) -> FilterResult:
|
94 |
+
try:
|
95 |
+
matched_rules = []
|
96 |
+
categories = set()
|
97 |
+
risk_score = 0
|
98 |
+
filtered = content
|
99 |
+
is_allowed = True
|
100 |
+
|
101 |
+
for name, rule in self.rules.items():
|
102 |
+
pattern = self.compiled_rules[name]
|
103 |
+
matches = pattern.findall(filtered)
|
104 |
+
|
105 |
+
if matches:
|
106 |
+
matched_rules.append(name)
|
107 |
+
categories.add(rule.category)
|
108 |
+
risk_score = max(risk_score, rule.severity)
|
109 |
+
|
110 |
+
if rule.action == "block":
|
111 |
+
is_allowed = False
|
112 |
+
elif rule.action == "sanitize":
|
113 |
+
filtered = pattern.sub(rule.replacement, filtered)
|
114 |
+
|
115 |
+
result = FilterResult(
|
116 |
+
is_allowed=is_allowed,
|
117 |
+
filtered_content=filtered if is_allowed else "[CONTENT BLOCKED]",
|
118 |
+
matched_rules=matched_rules,
|
119 |
+
risk_score=risk_score,
|
120 |
+
categories=categories,
|
121 |
+
details={
|
122 |
+
"original_length": len(content),
|
123 |
+
"filtered_length": len(filtered),
|
124 |
+
"rule_matches": len(matched_rules),
|
125 |
+
"context": context or {}
|
126 |
+
}
|
127 |
+
)
|
128 |
+
|
129 |
+
if matched_rules and self.security_logger:
|
130 |
+
self.security_logger.log_security_event(
|
131 |
+
"content_filtered",
|
132 |
+
matched_rules=matched_rules,
|
133 |
+
categories=[c.value for c in categories],
|
134 |
+
risk_score=risk_score,
|
135 |
+
is_allowed=is_allowed
|
136 |
+
)
|
137 |
+
|
138 |
+
return result
|
139 |
+
|
140 |
+
except Exception as e:
|
141 |
+
if self.security_logger:
|
142 |
+
self.security_logger.log_security_event(
|
143 |
+
"filter_error",
|
144 |
+
error=str(e),
|
145 |
+
content_length=len(content)
|
146 |
+
)
|
147 |
+
raise ValidationError(f"Content filtering failed: {str(e)}")
|
148 |
+
|
149 |
+
def add_rule(self, name: str, rule: FilterRule) -> None:
|
150 |
+
self.rules[name] = rule
|
151 |
+
self.compiled_rules[name] = re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
|
152 |
+
|
153 |
+
def remove_rule(self, name: str) -> None:
|
154 |
+
self.rules.pop(name, None)
|
155 |
+
self.compiled_rules.pop(name, None)
|
156 |
+
|
157 |
+
def get_rules(self) -> Dict[str, Dict[str, Any]]:
|
158 |
+
return {
|
159 |
+
name: {
|
160 |
+
"pattern": rule.pattern,
|
161 |
+
"category": rule.category.value,
|
162 |
+
"severity": rule.severity,
|
163 |
+
"description": rule.description,
|
164 |
+
"action": rule.action
|
165 |
+
}
|
166 |
+
for name, rule in self.rules.items()
|
167 |
+
}
|
src/llmguardian/defenders/context_validator.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
defenders/context_validator.py - Context validation for LLM interactions
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, Optional, List, Any
|
6 |
+
from dataclasses import dataclass
|
7 |
+
from datetime import datetime
|
8 |
+
import hashlib
|
9 |
+
from ..core.logger import SecurityLogger
|
10 |
+
from ..core.exceptions import ValidationError
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class ContextRule:
|
14 |
+
max_age: int # seconds
|
15 |
+
required_fields: List[str]
|
16 |
+
forbidden_fields: List[str]
|
17 |
+
max_depth: int
|
18 |
+
checksum_fields: List[str]
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class ValidationResult:
|
22 |
+
is_valid: bool
|
23 |
+
errors: List[str]
|
24 |
+
modified_context: Dict[str, Any]
|
25 |
+
metadata: Dict[str, Any]
|
26 |
+
|
27 |
+
class ContextValidator:
|
28 |
+
def __init__(self, security_logger: Optional[SecurityLogger] = None):
|
29 |
+
self.security_logger = security_logger
|
30 |
+
self.rule = ContextRule(
|
31 |
+
max_age=3600,
|
32 |
+
required_fields=["user_id", "session_id", "timestamp"],
|
33 |
+
forbidden_fields=["password", "secret", "token"],
|
34 |
+
max_depth=5,
|
35 |
+
checksum_fields=["user_id", "session_id"]
|
36 |
+
)
|
37 |
+
|
38 |
+
def validate_context(self, context: Dict[str, Any], previous_context: Optional[Dict[str, Any]] = None) -> ValidationResult:
|
39 |
+
try:
|
40 |
+
errors = []
|
41 |
+
modified = context.copy()
|
42 |
+
|
43 |
+
# Check required fields
|
44 |
+
missing = [f for f in self.rule.required_fields if f not in context]
|
45 |
+
if missing:
|
46 |
+
errors.append(f"Missing required fields: {missing}")
|
47 |
+
|
48 |
+
# Check forbidden fields
|
49 |
+
forbidden = [f for f in self.rule.forbidden_fields if f in context]
|
50 |
+
if forbidden:
|
51 |
+
errors.append(f"Forbidden fields present: {forbidden}")
|
52 |
+
for field in forbidden:
|
53 |
+
modified.pop(field, None)
|
54 |
+
|
55 |
+
# Validate timestamp
|
56 |
+
if "timestamp" in context:
|
57 |
+
age = (datetime.utcnow() - datetime.fromisoformat(str(context["timestamp"]))).seconds
|
58 |
+
if age > self.rule.max_age:
|
59 |
+
errors.append(f"Context too old: {age} seconds")
|
60 |
+
|
61 |
+
# Check context depth
|
62 |
+
if not self._check_depth(context, 0):
|
63 |
+
errors.append(f"Context exceeds max depth of {self.rule.max_depth}")
|
64 |
+
|
65 |
+
# Verify checksums if previous context exists
|
66 |
+
if previous_context:
|
67 |
+
if not self._verify_checksums(context, previous_context):
|
68 |
+
errors.append("Context checksum mismatch")
|
69 |
+
|
70 |
+
# Build metadata
|
71 |
+
metadata = {
|
72 |
+
"validation_time": datetime.utcnow().isoformat(),
|
73 |
+
"original_size": len(str(context)),
|
74 |
+
"modified_size": len(str(modified)),
|
75 |
+
"changes": len(errors)
|
76 |
+
}
|
77 |
+
|
78 |
+
result = ValidationResult(
|
79 |
+
is_valid=len(errors) == 0,
|
80 |
+
errors=errors,
|
81 |
+
modified_context=modified,
|
82 |
+
metadata=metadata
|
83 |
+
)
|
84 |
+
|
85 |
+
if errors and self.security_logger:
|
86 |
+
self.security_logger.log_security_event(
|
87 |
+
"context_validation_failure",
|
88 |
+
errors=errors,
|
89 |
+
context_id=context.get("context_id")
|
90 |
+
)
|
91 |
+
|
92 |
+
return result
|
93 |
+
|
94 |
+
except Exception as e:
|
95 |
+
if self.security_logger:
|
96 |
+
self.security_logger.log_security_event(
|
97 |
+
"context_validation_error",
|
98 |
+
error=str(e)
|
99 |
+
)
|
100 |
+
raise ValidationError(f"Context validation failed: {str(e)}")
|
101 |
+
|
102 |
+
def _check_depth(self, obj: Any, depth: int) -> bool:
|
103 |
+
if depth > self.rule.max_depth:
|
104 |
+
return False
|
105 |
+
if isinstance(obj, dict):
|
106 |
+
return all(self._check_depth(v, depth + 1) for v in obj.values())
|
107 |
+
if isinstance(obj, list):
|
108 |
+
return all(self._check_depth(v, depth + 1) for v in obj)
|
109 |
+
return True
|
110 |
+
|
111 |
+
def _verify_checksums(self, current: Dict[str, Any], previous: Dict[str, Any]) -> bool:
|
112 |
+
for field in self.rule.checksum_fields:
|
113 |
+
if field in current and field in previous:
|
114 |
+
current_hash = hashlib.sha256(str(current[field]).encode()).hexdigest()
|
115 |
+
previous_hash = hashlib.sha256(str(previous[field]).encode()).hexdigest()
|
116 |
+
if current_hash != previous_hash:
|
117 |
+
return False
|
118 |
+
return True
|
119 |
+
|
120 |
+
def update_rule(self, updates: Dict[str, Any]) -> None:
|
121 |
+
for key, value in updates.items():
|
122 |
+
if hasattr(self.rule, key):
|
123 |
+
setattr(self.rule, key, value)
|
src/llmguardian/defenders/input_sanitizer.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
defenders/input_sanitizer.py - Input sanitization for LLM inputs
|
3 |
+
"""
|
4 |
+
|
5 |
+
import re
|
6 |
+
from typing import Dict, Any, List, Optional
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from ..core.logger import SecurityLogger
|
9 |
+
from ..core.exceptions import ValidationError
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class SanitizationRule:
|
13 |
+
pattern: str
|
14 |
+
replacement: str
|
15 |
+
description: str
|
16 |
+
enabled: bool = True
|
17 |
+
|
18 |
+
@dataclass
|
19 |
+
class SanitizationResult:
|
20 |
+
original: str
|
21 |
+
sanitized: str
|
22 |
+
applied_rules: List[str]
|
23 |
+
is_modified: bool
|
24 |
+
risk_level: str
|
25 |
+
|
26 |
+
class InputSanitizer:
|
27 |
+
def __init__(self, security_logger: Optional[SecurityLogger] = None):
|
28 |
+
self.security_logger = security_logger
|
29 |
+
self.rules = self._initialize_rules()
|
30 |
+
self.compiled_rules = {
|
31 |
+
name: re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
|
32 |
+
for name, rule in self.rules.items()
|
33 |
+
if rule.enabled
|
34 |
+
}
|
35 |
+
|
36 |
+
def _initialize_rules(self) -> Dict[str, SanitizationRule]:
|
37 |
+
return {
|
38 |
+
"system_instructions": SanitizationRule(
|
39 |
+
pattern=r"system:\s*|instruction:\s*",
|
40 |
+
replacement=" ",
|
41 |
+
description="Remove system instruction markers"
|
42 |
+
),
|
43 |
+
"code_injection": SanitizationRule(
|
44 |
+
pattern=r"<script.*?>.*?</script>",
|
45 |
+
replacement="",
|
46 |
+
description="Remove script tags"
|
47 |
+
),
|
48 |
+
"delimiter_injection": SanitizationRule(
|
49 |
+
pattern=r"[<\[{](?:system|prompt|instruction)[>\]}]",
|
50 |
+
replacement="",
|
51 |
+
description="Remove delimiter-based injections"
|
52 |
+
),
|
53 |
+
"command_injection": SanitizationRule(
|
54 |
+
pattern=r"(?:exec|eval|system)\s*\(",
|
55 |
+
replacement="",
|
56 |
+
description="Remove command execution attempts"
|
57 |
+
),
|
58 |
+
"encoding_patterns": SanitizationRule(
|
59 |
+
pattern=r"(?:base64|hex|rot13)\s*\(",
|
60 |
+
replacement="",
|
61 |
+
description="Remove encoding attempts"
|
62 |
+
),
|
63 |
+
}
|
64 |
+
|
65 |
+
def sanitize(self, input_text: str, context: Optional[Dict[str, Any]] = None) -> SanitizationResult:
|
66 |
+
original = input_text
|
67 |
+
applied_rules = []
|
68 |
+
is_modified = False
|
69 |
+
|
70 |
+
try:
|
71 |
+
sanitized = input_text
|
72 |
+
for name, rule in self.rules.items():
|
73 |
+
if not rule.enabled:
|
74 |
+
continue
|
75 |
+
|
76 |
+
pattern = self.compiled_rules.get(name)
|
77 |
+
if not pattern:
|
78 |
+
continue
|
79 |
+
|
80 |
+
new_text = pattern.sub(rule.replacement, sanitized)
|
81 |
+
if new_text != sanitized:
|
82 |
+
applied_rules.append(name)
|
83 |
+
is_modified = True
|
84 |
+
sanitized = new_text
|
85 |
+
|
86 |
+
risk_level = self._assess_risk(applied_rules)
|
87 |
+
|
88 |
+
if is_modified and self.security_logger:
|
89 |
+
self.security_logger.log_security_event(
|
90 |
+
"input_sanitization",
|
91 |
+
original_length=len(original),
|
92 |
+
sanitized_length=len(sanitized),
|
93 |
+
applied_rules=applied_rules,
|
94 |
+
risk_level=risk_level
|
95 |
+
)
|
96 |
+
|
97 |
+
return SanitizationResult(
|
98 |
+
original=original,
|
99 |
+
sanitized=sanitized,
|
100 |
+
applied_rules=applied_rules,
|
101 |
+
is_modified=is_modified,
|
102 |
+
risk_level=risk_level
|
103 |
+
)
|
104 |
+
|
105 |
+
except Exception as e:
|
106 |
+
if self.security_logger:
|
107 |
+
self.security_logger.log_security_event(
|
108 |
+
"sanitization_error",
|
109 |
+
error=str(e),
|
110 |
+
input_length=len(input_text)
|
111 |
+
)
|
112 |
+
raise ValidationError(f"Sanitization failed: {str(e)}")
|
113 |
+
|
114 |
+
def _assess_risk(self, applied_rules: List[str]) -> str:
|
115 |
+
if not applied_rules:
|
116 |
+
return "low"
|
117 |
+
if len(applied_rules) >= 3:
|
118 |
+
return "high"
|
119 |
+
if "command_injection" in applied_rules or "code_injection" in applied_rules:
|
120 |
+
return "high"
|
121 |
+
return "medium"
|
122 |
+
|
123 |
+
def add_rule(self, name: str, rule: SanitizationRule) -> None:
|
124 |
+
self.rules[name] = rule
|
125 |
+
if rule.enabled:
|
126 |
+
self.compiled_rules[name] = re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
|
127 |
+
|
128 |
+
def remove_rule(self, name: str) -> None:
|
129 |
+
self.rules.pop(name, None)
|
130 |
+
self.compiled_rules.pop(name, None)
|
131 |
+
|
132 |
+
def get_rules(self) -> Dict[str, Dict[str, Any]]:
|
133 |
+
return {
|
134 |
+
name: {
|
135 |
+
"pattern": rule.pattern,
|
136 |
+
"replacement": rule.replacement,
|
137 |
+
"description": rule.description,
|
138 |
+
"enabled": rule.enabled
|
139 |
+
}
|
140 |
+
for name, rule in self.rules.items()
|
141 |
+
}
|
src/llmguardian/defenders/output_validator.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
defenders/output_validator.py - Output validation and sanitization
|
3 |
+
"""
|
4 |
+
|
5 |
+
import re
|
6 |
+
from typing import Dict, List, Optional, Set, Any
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from ..core.logger import SecurityLogger
|
9 |
+
from ..core.exceptions import ValidationError
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class ValidationRule:
|
13 |
+
pattern: str
|
14 |
+
description: str
|
15 |
+
severity: int # 1-10
|
16 |
+
block: bool = True
|
17 |
+
sanitize: bool = True
|
18 |
+
replacement: str = ""
|
19 |
+
|
20 |
+
@dataclass
|
21 |
+
class ValidationResult:
|
22 |
+
is_valid: bool
|
23 |
+
violations: List[str]
|
24 |
+
sanitized_output: Optional[str]
|
25 |
+
risk_score: int
|
26 |
+
details: Dict[str, Any]
|
27 |
+
|
28 |
+
class OutputValidator:
|
29 |
+
def __init__(self, security_logger: Optional[SecurityLogger] = None):
|
30 |
+
self.security_logger = security_logger
|
31 |
+
self.rules = self._initialize_rules()
|
32 |
+
self.compiled_rules = {
|
33 |
+
name: re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
|
34 |
+
for name, rule in self.rules.items()
|
35 |
+
}
|
36 |
+
self.sensitive_patterns = self._initialize_sensitive_patterns()
|
37 |
+
|
38 |
+
def _initialize_rules(self) -> Dict[str, ValidationRule]:
|
39 |
+
return {
|
40 |
+
"sql_injection": ValidationRule(
|
41 |
+
pattern=r"(?:SELECT|INSERT|UPDATE|DELETE)\s+(?:FROM|INTO)\s+\w+",
|
42 |
+
description="SQL query in output",
|
43 |
+
severity=9,
|
44 |
+
block=True
|
45 |
+
),
|
46 |
+
"code_injection": ValidationRule(
|
47 |
+
pattern=r"<script.*?>.*?</script>",
|
48 |
+
description="JavaScript code in output",
|
49 |
+
severity=8,
|
50 |
+
block=True
|
51 |
+
),
|
52 |
+
"system_info": ValidationRule(
|
53 |
+
pattern=r"(?:system|config|env|secret)(?:_|\s+)?(?:key|token|password)",
|
54 |
+
description="System information leak",
|
55 |
+
severity=9,
|
56 |
+
block=True
|
57 |
+
),
|
58 |
+
"personal_data": ValidationRule(
|
59 |
+
pattern=r"\b\d{3}-\d{2}-\d{4}\b|\b\d{16}\b",
|
60 |
+
description="Personal data (SSN/CC)",
|
61 |
+
severity=10,
|
62 |
+
block=True
|
63 |
+
),
|
64 |
+
"file_paths": ValidationRule(
|
65 |
+
pattern=r"(?:/[\w./]+)|(?:C:\\[\w\\]+)",
|
66 |
+
description="File system paths",
|
67 |
+
severity=7,
|
68 |
+
block=True
|
69 |
+
),
|
70 |
+
"html_content": ValidationRule(
|
71 |
+
pattern=r"<(?!br|p|b|i|em|strong)[^>]+>",
|
72 |
+
description="HTML content",
|
73 |
+
severity=6,
|
74 |
+
sanitize=True,
|
75 |
+
replacement=""
|
76 |
+
),
|
77 |
+
}
|
78 |
+
|
79 |
+
def _initialize_sensitive_patterns(self) -> Set[str]:
|
80 |
+
return {
|
81 |
+
r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b", # Email
|
82 |
+
r"\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b", # IP address
|
83 |
+
r"(?i)api[_-]?key", # API keys
|
84 |
+
r"(?i)password|passwd|pwd", # Passwords
|
85 |
+
r"(?i)token|secret|credential", # Credentials
|
86 |
+
r"\b[A-Z0-9]{20,}\b", # Long alphanumeric strings
|
87 |
+
}
|
88 |
+
|
89 |
+
def validate(self, output: str, context: Optional[Dict[str, Any]] = None) -> ValidationResult:
|
90 |
+
try:
|
91 |
+
violations = []
|
92 |
+
risk_score = 0
|
93 |
+
sanitized = output
|
94 |
+
is_valid = True
|
95 |
+
|
96 |
+
# Check against validation rules
|
97 |
+
for name, rule in self.rules.items():
|
98 |
+
pattern = self.compiled_rules[name]
|
99 |
+
matches = pattern.findall(sanitized)
|
100 |
+
|
101 |
+
if matches:
|
102 |
+
violations.append(f"{name}: {rule.description}")
|
103 |
+
risk_score = max(risk_score, rule.severity)
|
104 |
+
|
105 |
+
if rule.block:
|
106 |
+
is_valid = False
|
107 |
+
|
108 |
+
if rule.sanitize:
|
109 |
+
sanitized = pattern.sub(rule.replacement, sanitized)
|
110 |
+
|
111 |
+
# Check for sensitive data patterns
|
112 |
+
for pattern in self.sensitive_patterns:
|
113 |
+
matches = re.findall(pattern, sanitized)
|
114 |
+
if matches:
|
115 |
+
violations.append(f"Sensitive data detected: {pattern}")
|
116 |
+
risk_score = max(risk_score, 8)
|
117 |
+
is_valid = False
|
118 |
+
sanitized = re.sub(pattern, "[REDACTED]", sanitized)
|
119 |
+
|
120 |
+
result = ValidationResult(
|
121 |
+
is_valid=is_valid,
|
122 |
+
violations=violations,
|
123 |
+
sanitized_output=sanitized if violations else output,
|
124 |
+
risk_score=risk_score,
|
125 |
+
details={
|
126 |
+
"original_length": len(output),
|
127 |
+
"sanitized_length": len(sanitized),
|
128 |
+
"violation_count": len(violations),
|
129 |
+
"context": context or {}
|
130 |
+
}
|
131 |
+
)
|
132 |
+
|
133 |
+
if violations and self.security_logger:
|
134 |
+
self.security_logger.log_security_event(
|
135 |
+
"output_validation",
|
136 |
+
violations=violations,
|
137 |
+
risk_score=risk_score,
|
138 |
+
is_valid=is_valid
|
139 |
+
)
|
140 |
+
|
141 |
+
return result
|
142 |
+
|
143 |
+
except Exception as e:
|
144 |
+
if self.security_logger:
|
145 |
+
self.security_logger.log_security_event(
|
146 |
+
"validation_error",
|
147 |
+
error=str(e),
|
148 |
+
output_length=len(output)
|
149 |
+
)
|
150 |
+
raise ValidationError(f"Output validation failed: {str(e)}")
|
151 |
+
|
152 |
+
def add_rule(self, name: str, rule: ValidationRule) -> None:
|
153 |
+
self.rules[name] = rule
|
154 |
+
self.compiled_rules[name] = re.compile(rule.pattern, re.IGNORECASE | re.MULTILINE)
|
155 |
+
|
156 |
+
def remove_rule(self, name: str) -> None:
|
157 |
+
self.rules.pop(name, None)
|
158 |
+
self.compiled_rules.pop(name, None)
|
159 |
+
|
160 |
+
def add_sensitive_pattern(self, pattern: str) -> None:
|
161 |
+
self.sensitive_patterns.add(pattern)
|
162 |
+
|
163 |
+
def get_rules(self) -> Dict[str, Dict[str, Any]]:
|
164 |
+
return {
|
165 |
+
name: {
|
166 |
+
"pattern": rule.pattern,
|
167 |
+
"description": rule.description,
|
168 |
+
"severity": rule.severity,
|
169 |
+
"block": rule.block,
|
170 |
+
"sanitize": rule.sanitize
|
171 |
+
}
|
172 |
+
for name, rule in self.rules.items()
|
173 |
+
}
|
src/llmguardian/defenders/test_context_validator.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
tests/defenders/test_context_validator.py - Tests for context validation
|
3 |
+
"""
|
4 |
+
|
5 |
+
import pytest
|
6 |
+
from datetime import datetime, timedelta
|
7 |
+
from llmguardian.defenders.context_validator import ContextValidator, ValidationResult
|
8 |
+
from llmguardian.core.exceptions import ValidationError
|
9 |
+
|
10 |
+
@pytest.fixture
|
11 |
+
def validator():
|
12 |
+
return ContextValidator()
|
13 |
+
|
14 |
+
@pytest.fixture
|
15 |
+
def valid_context():
|
16 |
+
return {
|
17 |
+
"user_id": "test_user",
|
18 |
+
"session_id": "test_session",
|
19 |
+
"timestamp": datetime.utcnow().isoformat(),
|
20 |
+
"request_id": "123",
|
21 |
+
"metadata": {
|
22 |
+
"source": "test",
|
23 |
+
"version": "1.0"
|
24 |
+
}
|
25 |
+
}
|
26 |
+
|
27 |
+
def test_valid_context(validator, valid_context):
|
28 |
+
result = validator.validate_context(valid_context)
|
29 |
+
assert result.is_valid
|
30 |
+
assert not result.errors
|
31 |
+
assert result.modified_context == valid_context
|
32 |
+
|
33 |
+
def test_missing_required_fields(validator):
|
34 |
+
context = {
|
35 |
+
"user_id": "test_user",
|
36 |
+
"timestamp": datetime.utcnow().isoformat()
|
37 |
+
}
|
38 |
+
result = validator.validate_context(context)
|
39 |
+
assert not result.is_valid
|
40 |
+
assert "Missing required fields" in result.errors[0]
|
41 |
+
|
42 |
+
def test_forbidden_fields(validator, valid_context):
|
43 |
+
context = valid_context.copy()
|
44 |
+
context["password"] = "secret123"
|
45 |
+
result = validator.validate_context(context)
|
46 |
+
assert not result.is_valid
|
47 |
+
assert "Forbidden fields present" in result.errors[0]
|
48 |
+
assert "password" not in result.modified_context
|
49 |
+
|
50 |
+
def test_context_age(validator, valid_context):
|
51 |
+
old_context = valid_context.copy()
|
52 |
+
old_context["timestamp"] = (
|
53 |
+
datetime.utcnow() - timedelta(hours=2)
|
54 |
+
).isoformat()
|
55 |
+
result = validator.validate_context(old_context)
|
56 |
+
assert not result.is_valid
|
57 |
+
assert "Context too old" in result.errors[0]
|
58 |
+
|
59 |
+
def test_context_depth(validator, valid_context):
|
60 |
+
deep_context = valid_context.copy()
|
61 |
+
current = deep_context
|
62 |
+
for i in range(10):
|
63 |
+
current["nested"] = {}
|
64 |
+
current = current["nested"]
|
65 |
+
result = validator.validate_context(deep_context)
|
66 |
+
assert not result.is_valid
|
67 |
+
assert "Context exceeds max depth" in result.errors[0]
|
68 |
+
|
69 |
+
def test_checksum_verification(validator, valid_context):
|
70 |
+
previous_context = valid_context.copy()
|
71 |
+
modified_context = valid_context.copy()
|
72 |
+
modified_context["user_id"] = "different_user"
|
73 |
+
result = validator.validate_context(modified_context, previous_context)
|
74 |
+
assert not result.is_valid
|
75 |
+
assert "Context checksum mismatch" in result.errors[0]
|
76 |
+
|
77 |
+
def test_update_rule(validator):
|
78 |
+
validator.update_rule({"max_age": 7200})
|
79 |
+
old_context = {
|
80 |
+
"user_id": "test_user",
|
81 |
+
"session_id": "test_session",
|
82 |
+
"timestamp": (
|
83 |
+
datetime.utcnow() - timedelta(hours=1.5)
|
84 |
+
).isoformat()
|
85 |
+
}
|
86 |
+
result = validator.validate_context(old_context)
|
87 |
+
assert result.is_valid
|
88 |
+
|
89 |
+
def test_exception_handling(validator):
|
90 |
+
with pytest.raises(ValidationError):
|
91 |
+
validator.validate_context({"timestamp": "invalid_date"})
|
92 |
+
|
93 |
+
def test_metadata_generation(validator, valid_context):
|
94 |
+
result = validator.validate_context(valid_context)
|
95 |
+
assert "validation_time" in result.metadata
|
96 |
+
assert "original_size" in result.metadata
|
97 |
+
assert "modified_size" in result.metadata
|
98 |
+
assert "changes" in result.metadata
|
src/llmguardian/defenders/token_validator.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
defenders/token_validator.py - Token and credential validation
|
3 |
+
"""
|
4 |
+
|
5 |
+
from typing import Dict, Optional, Any, List
|
6 |
+
from dataclasses import dataclass
|
7 |
+
import re
|
8 |
+
import jwt
|
9 |
+
from datetime import datetime, timedelta
|
10 |
+
from ..core.logger import SecurityLogger
|
11 |
+
from ..core.exceptions import TokenValidationError
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class TokenRule:
|
15 |
+
pattern: str
|
16 |
+
description: str
|
17 |
+
min_length: int
|
18 |
+
max_length: int
|
19 |
+
required_chars: str
|
20 |
+
expiry_time: int # in seconds
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class TokenValidationResult:
|
24 |
+
is_valid: bool
|
25 |
+
errors: List[str]
|
26 |
+
metadata: Dict[str, Any]
|
27 |
+
expiry: Optional[datetime]
|
28 |
+
|
29 |
+
class TokenValidator:
|
30 |
+
def __init__(self, security_logger: Optional[SecurityLogger] = None):
|
31 |
+
self.security_logger = security_logger
|
32 |
+
self.rules = self._initialize_rules()
|
33 |
+
self.secret_key = self._load_secret_key()
|
34 |
+
|
35 |
+
def _initialize_rules(self) -> Dict[str, TokenRule]:
|
36 |
+
return {
|
37 |
+
"jwt": TokenRule(
|
38 |
+
pattern=r"^[A-Za-z0-9-_=]+\.[A-Za-z0-9-_=]+\.[A-Za-z0-9-_.+/=]+$",
|
39 |
+
description="JWT token",
|
40 |
+
min_length=32,
|
41 |
+
max_length=4096,
|
42 |
+
required_chars=".-_",
|
43 |
+
expiry_time=3600
|
44 |
+
),
|
45 |
+
"api_key": TokenRule(
|
46 |
+
pattern=r"^[A-Za-z0-9]{32,64}$",
|
47 |
+
description="API key",
|
48 |
+
min_length=32,
|
49 |
+
max_length=64,
|
50 |
+
required_chars="",
|
51 |
+
expiry_time=86400
|
52 |
+
),
|
53 |
+
"session_token": TokenRule(
|
54 |
+
pattern=r"^[A-Fa-f0-9]{64}$",
|
55 |
+
description="Session token",
|
56 |
+
min_length=64,
|
57 |
+
max_length=64,
|
58 |
+
required_chars="",
|
59 |
+
expiry_time=7200
|
60 |
+
)
|
61 |
+
}
|
62 |
+
|
63 |
+
def _load_secret_key(self) -> bytes:
|
64 |
+
# Implementation would load from secure storage
|
65 |
+
return b"your-256-bit-secret"
|
66 |
+
|
67 |
+
def validate_token(self, token: str, token_type: str) -> TokenValidationResult:
|
68 |
+
try:
|
69 |
+
if token_type not in self.rules:
|
70 |
+
raise TokenValidationError(f"Unknown token type: {token_type}")
|
71 |
+
|
72 |
+
rule = self.rules[token_type]
|
73 |
+
errors = []
|
74 |
+
metadata = {}
|
75 |
+
|
76 |
+
# Length validation
|
77 |
+
if len(token) < rule.min_length or len(token) > rule.max_length:
|
78 |
+
errors.append(f"Token length must be between {rule.min_length} and {rule.max_length}")
|
79 |
+
|
80 |
+
# Pattern validation
|
81 |
+
if not re.match(rule.pattern, token):
|
82 |
+
errors.append("Token format is invalid")
|
83 |
+
|
84 |
+
# Required characters
|
85 |
+
if rule.required_chars:
|
86 |
+
missing_chars = set(rule.required_chars) - set(token)
|
87 |
+
if missing_chars:
|
88 |
+
errors.append(f"Token missing required characters: {missing_chars}")
|
89 |
+
|
90 |
+
# JWT-specific validation
|
91 |
+
if token_type == "jwt":
|
92 |
+
try:
|
93 |
+
payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
|
94 |
+
metadata = payload
|
95 |
+
exp = datetime.fromtimestamp(payload.get("exp", 0))
|
96 |
+
if exp < datetime.utcnow():
|
97 |
+
errors.append("Token has expired")
|
98 |
+
except jwt.InvalidTokenError as e:
|
99 |
+
errors.append(f"Invalid JWT: {str(e)}")
|
100 |
+
|
101 |
+
is_valid = len(errors) == 0
|
102 |
+
expiry = datetime.utcnow() + timedelta(seconds=rule.expiry_time)
|
103 |
+
|
104 |
+
if not is_valid and self.security_logger:
|
105 |
+
self.security_logger.log_security_event(
|
106 |
+
"token_validation_failure",
|
107 |
+
token_type=token_type,
|
108 |
+
errors=errors
|
109 |
+
)
|
110 |
+
|
111 |
+
return TokenValidationResult(
|
112 |
+
is_valid=is_valid,
|
113 |
+
errors=errors,
|
114 |
+
metadata=metadata,
|
115 |
+
expiry=expiry if is_valid else None
|
116 |
+
)
|
117 |
+
|
118 |
+
except Exception as e:
|
119 |
+
if self.security_logger:
|
120 |
+
self.security_logger.log_security_event(
|
121 |
+
"token_validation_error",
|
122 |
+
error=str(e)
|
123 |
+
)
|
124 |
+
raise TokenValidationError(f"Validation failed: {str(e)}")
|
125 |
+
|
126 |
+
def create_token(self, token_type: str, payload: Dict[str, Any]) -> str:
|
127 |
+
if token_type not in self.rules:
|
128 |
+
raise TokenValidationError(f"Unknown token type: {token_type}")
|
129 |
+
|
130 |
+
try:
|
131 |
+
if token_type == "jwt":
|
132 |
+
expiry = datetime.utcnow() + timedelta(
|
133 |
+
seconds=self.rules[token_type].expiry_time
|
134 |
+
)
|
135 |
+
payload["exp"] = expiry.timestamp()
|
136 |
+
return jwt.encode(payload, self.secret_key, algorithm="HS256")
|
137 |
+
|
138 |
+
# Add other token type creation logic here
|
139 |
+
raise TokenValidationError(f"Token creation not implemented for {token_type}")
|
140 |
+
|
141 |
+
except Exception as e:
|
142 |
+
if self.security_logger:
|
143 |
+
self.security_logger.log_security_event(
|
144 |
+
"token_creation_error",
|
145 |
+
error=str(e)
|
146 |
+
)
|
147 |
+
raise TokenValidationError(f"Token creation failed: {str(e)}")
|