File size: 4,448 Bytes
4a3bb46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""
defenders/context_validator.py - Context validation for LLM interactions
"""

from typing import Dict, Optional, List, Any
from dataclasses import dataclass
from datetime import datetime
import hashlib
from ..core.logger import SecurityLogger
from ..core.exceptions import ValidationError

@dataclass
class ContextRule:
   """Rule set that ContextValidator applies to a context dictionary."""

   max_age: int  # seconds; a context whose "timestamp" is older than this is reported as too old
   required_fields: List[str]  # top-level keys that must be present in the context
   forbidden_fields: List[str]  # top-level keys that are reported and stripped from the returned copy
   max_depth: int  # deepest nesting of dicts/lists allowed (the context itself is depth 0)
   checksum_fields: List[str]  # keys whose values must match between the current and previous context

@dataclass
class ValidationResult:
   """Outcome returned by ContextValidator.validate_context."""

   is_valid: bool  # True only when no errors were collected
   errors: List[str]  # human-readable message for each failed check
   modified_context: Dict[str, Any]  # shallow copy of the input with forbidden top-level keys removed
   metadata: Dict[str, Any]  # validation_time, original_size, modified_size and changes (error count)

class ContextValidator:
    """Validate context dictionaries that accompany LLM interactions.

    The validator checks a context against ``self.rule`` (a ``ContextRule``):
    required and forbidden top-level fields, timestamp freshness, nesting
    depth, and consistency of selected fields with a previous context.
    """

    def __init__(self, security_logger: Optional["SecurityLogger"] = None):
        """Create a validator with the default rule set.

        Args:
            security_logger: Optional logger; when given, validation failures
                and unexpected errors are reported through
                ``log_security_event``.
        """
        self.security_logger = security_logger
        self.rule = ContextRule(
            max_age=3600,
            required_fields=["user_id", "session_id", "timestamp"],
            forbidden_fields=["password", "secret", "token"],
            max_depth=5,
            checksum_fields=["user_id", "session_id"],
        )

    def validate_context(
        self,
        context: Dict[str, Any],
        previous_context: Optional[Dict[str, Any]] = None,
    ) -> "ValidationResult":
        """Run every check against ``context`` and collect the failures.

        Args:
            context: The context dictionary to validate. It is not mutated.
            previous_context: Optional earlier context; when non-empty, the
                rule's checksum fields must match between the two.

        Returns:
            A ``ValidationResult`` whose ``modified_context`` is a shallow
            copy of ``context`` with forbidden top-level fields removed.

        Raises:
            ValidationError: If an unexpected error occurs while validating
                (for example ``context`` is not a dictionary). The original
                exception is chained as ``__cause__``.
        """
        try:
            errors: List[str] = []
            modified = context.copy()

            # Check required fields
            missing = [f for f in self.rule.required_fields if f not in context]
            if missing:
                errors.append(f"Missing required fields: {missing}")

            # Check forbidden fields; they are reported AND stripped from the copy
            forbidden = [f for f in self.rule.forbidden_fields if f in context]
            if forbidden:
                errors.append(f"Forbidden fields present: {forbidden}")
                for field in forbidden:
                    modified.pop(field, None)

            # Validate timestamp
            if "timestamp" in context:
                timestamp_error = self._timestamp_error(context["timestamp"])
                if timestamp_error is not None:
                    errors.append(timestamp_error)

            # Check context depth
            if not self._check_depth(context, 0):
                errors.append(f"Context exceeds max depth of {self.rule.max_depth}")

            # Verify checksums if previous context exists
            if previous_context:
                if not self._verify_checksums(context, previous_context):
                    errors.append("Context checksum mismatch")

            # Build metadata
            metadata = {
                "validation_time": datetime.utcnow().isoformat(),
                "original_size": len(str(context)),
                "modified_size": len(str(modified)),
                "changes": len(errors),
            }

            result = ValidationResult(
                is_valid=not errors,
                errors=errors,
                modified_context=modified,
                metadata=metadata,
            )

            if errors and self.security_logger:
                self.security_logger.log_security_event(
                    "context_validation_failure",
                    errors=errors,
                    context_id=context.get("context_id"),
                )

            return result

        except Exception as exc:
            if self.security_logger:
                self.security_logger.log_security_event(
                    "context_validation_error",
                    error=str(exc),
                )
            # Chain the cause so the original traceback is not lost.
            raise ValidationError(f"Context validation failed: {str(exc)}") from exc

    def _timestamp_error(self, raw_timestamp: Any) -> Optional[str]:
        """Return an error message for a bad timestamp, or ``None`` if it is fine.

        Args:
            raw_timestamp: Value of the context's ``"timestamp"`` entry; it is
                converted with ``str()`` and parsed as an ISO 8601 string.

        Returns:
            ``None`` when the timestamp parses, is not in the future, and is
            at most ``self.rule.max_age`` seconds old; otherwise a message
            describing the problem.
        """
        try:
            stamp = datetime.fromisoformat(str(raw_timestamp))
        except ValueError:
            # A malformed timestamp is a validation failure, not a crash.
            return "Invalid timestamp format"

        if stamp.tzinfo is not None:
            # Aware timestamps must be compared with an aware "now";
            # mixing naive and aware datetimes raises TypeError.
            now = datetime.now(stamp.tzinfo)
        else:
            # Naive timestamps are compared against naive UTC, as before.
            now = datetime.utcnow()

        # total_seconds() covers the whole interval; timedelta.seconds is only
        # the sub-day remainder and wraps for ages over a day or in the future.
        age = (now - stamp).total_seconds()
        if age < 0:
            return "Context timestamp is in the future"
        if age > self.rule.max_age:
            return f"Context too old: {int(age)} seconds"
        return None

    def _check_depth(self, obj: Any, depth: int) -> bool:
        """Return ``True`` if ``obj`` nests no deeper than ``self.rule.max_depth``.

        Args:
            obj: Value to inspect; only dicts and lists are descended into.
            depth: Depth of ``obj`` itself (callers start at 0).
        """
        if depth > self.rule.max_depth:
            return False
        if isinstance(obj, dict):
            return all(self._check_depth(v, depth + 1) for v in obj.values())
        if isinstance(obj, list):
            return all(self._check_depth(v, depth + 1) for v in obj)
        return True

    def _verify_checksums(self, current: Dict[str, Any], previous: Dict[str, Any]) -> bool:
        """Return ``True`` if every shared checksum field matches in both contexts.

        A field listed in ``self.rule.checksum_fields`` is compared only when
        it is present in both ``current`` and ``previous``; the comparison is
        made on the SHA-256 digest of the value's ``str()`` form.
        """
        for field in self.rule.checksum_fields:
            if field in current and field in previous:
                current_hash = hashlib.sha256(str(current[field]).encode()).hexdigest()
                previous_hash = hashlib.sha256(str(previous[field]).encode()).hexdigest()
                if current_hash != previous_hash:
                    return False
        return True

    def update_rule(self, updates: Dict[str, Any]) -> None:
        """Overwrite attributes of ``self.rule`` from ``updates``.

        Keys that are not existing attributes of the rule are silently
        ignored; values are assigned without any type or range checking.
        """
        for key, value in updates.items():
            if hasattr(self.rule, key):
                setattr(self.rule, key, value)