# Author: DeWitt Gibson
# Commit 740932b: Adding core functionality
"""
core/security.py - Core security services for LLMGuardian
"""
import hashlib
import hmac
import os
import re
import secrets
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any, List, Callable

import jwt

from .config import Config
from .logger import SecurityLogger, AuditLogger
@dataclass
class SecurityContext:
    """Who is making a request and what they are allowed to do.

    A plain data holder: the dataclass decorator generates ``__init__``,
    ``__eq__`` and ``__repr__``; no validation is performed.
    """

    # Identifier of the requesting user.
    user_id: str
    # Role names held by the user.
    roles: List[str]
    # Permission names held by the user.
    permissions: List[str]
    # Identifier of the session this context belongs to.
    session_id: str
    # Creation time of the context, stored exactly as supplied.
    timestamp: datetime
class RateLimiter:
    """Sliding-window rate limiter keyed by an arbitrary string.

    A key may make at most ``max_requests`` allowed requests in any
    ``time_window``-second span. Only allowed requests are recorded, so
    rejected attempts do not push the window forward.

    Args:
        max_requests: Allowed requests per key per window (<= 0 rejects all).
        time_window: Window length in seconds.
        clock: Zero-argument callable returning the current time in seconds.
            Defaults to ``time.monotonic`` so wall-clock adjustments (NTP
            steps, manual changes) cannot shorten or extend the window;
            inject a fake clock in tests.

    NOTE(review): histories are kept per key indefinitely and access is not
    synchronized; add eviction/locking if keys are unbounded or the limiter
    is shared across threads.
    """

    def __init__(self, max_requests: int, time_window: int,
                 clock: Callable[[], float] = time.monotonic):
        self.max_requests = max_requests
        self.time_window = time_window
        self._clock = clock
        # key -> deque of clock() readings for allowed requests, oldest first
        self.requests = {}

    def is_allowed(self, key: str) -> bool:
        """Check the rate limit for ``key`` and record the request if allowed.

        Returns:
            True if the request is allowed (it then counts against the
            window), False if ``key`` already made ``max_requests`` allowed
            requests within the last ``time_window`` seconds.
        """
        now = self._clock()
        history = self.requests.get(key)
        if history is None:
            history = deque()

        # Timestamps are appended in clock order, so expired entries are
        # always at the left. Pruning mutates the stored deque in place, so
        # stale entries are dropped even when this request is rejected.
        while history and now - history[0] >= self.time_window:
            history.popleft()

        if len(history) >= self.max_requests:
            return False

        history.append(now)
        self.requests[key] = history
        return True
class SecurityService:
    """Core security service.

    Bundles per-user rate limiting, HS256 JWT session tokens, HMAC-SHA256
    integrity checks, configuration-change auditing and basic prompt/output
    screening. The signing key used for tokens and HMACs is persisted in
    ``.secret_key`` in the current working directory.
    """

    # Relative path of the persisted signing key (resolved against the CWD).
    _SECRET_KEY_PATH = ".secret_key"

    # Fragments stripped from LLM output by ``sanitize_output``. Matched
    # case-insensitively so "drop table" or "Sudo " cannot slip past.
    _OUTPUT_BLOCKLIST = re.compile(
        r"sudo |rm -rf|DROP TABLE|DELETE FROM", re.IGNORECASE
    )

    def __init__(self, config: Config,
                 security_logger: SecurityLogger,
                 audit_logger: AuditLogger):
        """Initialize security service.

        Args:
            config: Application configuration; ``config.security.rate_limit``
                is read here and ``config.security.max_token_length`` by
                ``validate_prompt_security``.
            security_logger: Receives security events and validation results.
            audit_logger: Receives access and configuration-change records.

        Raises:
            ValueError: if the persisted signing key file exists but is empty.
        """
        self.config = config
        self.security_logger = security_logger
        self.audit_logger = audit_logger
        self.rate_limiter = RateLimiter(
            config.security.rate_limit,
            60  # 1 minute window
        )
        self.secret_key = self._load_or_generate_key()

    def _load_or_generate_key(self) -> bytes:
        """Load the signing key, generating and persisting it on first use.

        If ``.secret_key`` does not exist, 32 random bytes are written to it
        with owner-only (0o600) permissions. The file is created exclusively
        (O_EXCL) so two processes starting together cannot overwrite each
        other's key; the loser of that race reads the winner's key instead.

        Raises:
            ValueError: if the key file is empty. An empty key would let
                anyone forge tokens and HMACs.
        """
        path = self._SECRET_KEY_PATH
        try:
            with open(path, "rb") as f:
                key = f.read()
        except FileNotFoundError:
            key = secrets.token_bytes(32)
            flags = (os.O_WRONLY | os.O_CREAT | os.O_EXCL
                     | getattr(os, "O_BINARY", 0))  # O_BINARY: Windows only
            try:
                fd = os.open(path, flags, 0o600)
            except FileExistsError:
                # Created concurrently by another process: share its key.
                with open(path, "rb") as f:
                    key = f.read()
            else:
                with os.fdopen(fd, "wb") as f:
                    f.write(key)
                return key
        if not key:
            raise ValueError(f"secret key file {path!r} is empty")
        return key

    def create_security_context(self, user_id: str,
                                roles: List[str],
                                permissions: List[str]) -> SecurityContext:
        """Create a new security context with a fresh random session id.

        ``timestamp`` is a naive UTC datetime; ``verify_token`` round-trips it
        through ``isoformat``/``fromisoformat``.
        """
        return SecurityContext(
            user_id=user_id,
            roles=roles,
            permissions=permissions,
            session_id=secrets.token_urlsafe(16),
            timestamp=datetime.utcnow()
        )

    def validate_request(self, context: SecurityContext,
                         resource: str, action: str) -> bool:
        """Rate-limit and audit a request.

        Returns:
            False (after logging ``rate_limit_exceeded``) when the user is over
            the rate limit; otherwise logs the access and returns True.
        """
        if not self.rate_limiter.is_allowed(context.user_id):
            self.security_logger.log_security_event(
                "rate_limit_exceeded",
                user_id=context.user_id
            )
            return False
        self.audit_logger.log_access(
            user=context.user_id,
            resource=resource,
            action=action
        )
        return True

    def create_token(self, context: SecurityContext) -> str:
        """Create an HS256-signed JWT carrying ``context``; valid for one hour.

        The token is signed, not encrypted: anyone holding it can read the
        user id, roles, permissions and session id it contains.
        """
        payload = {
            "user_id": context.user_id,
            "roles": context.roles,
            "permissions": context.permissions,
            "session_id": context.session_id,
            "timestamp": context.timestamp.isoformat(),
            # Timezone-aware UTC; datetime.utcnow() is deprecated since 3.12.
            "exp": datetime.now(timezone.utc) + timedelta(hours=1)
        }
        return jwt.encode(payload, self.secret_key, algorithm="HS256")

    def verify_token(self, token: str) -> Optional[SecurityContext]:
        """Verify a token from ``create_token`` and rebuild its context.

        Returns:
            The decoded ``SecurityContext``, or None when the signature is
            invalid, the token is expired or malformed, or its payload lacks
            the expected claims. Every failure is logged as ``invalid_token``.
        """
        try:
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            return SecurityContext(
                user_id=payload["user_id"],
                roles=payload["roles"],
                permissions=payload["permissions"],
                session_id=payload["session_id"],
                timestamp=datetime.fromisoformat(payload["timestamp"])
            )
        except (jwt.InvalidTokenError, KeyError, TypeError, ValueError):
            # Log a short digest rather than a prefix: tokens sharing a JWT
            # header share their leading characters, so a prefix could not
            # tell tokens apart.
            self.security_logger.log_security_event(
                "invalid_token",
                token="sha256:" + self.hash_sensitive_data(token)[:16]
            )
            return None

    def hash_sensitive_data(self, data: str) -> str:
        """Return the hex SHA-256 digest of ``data`` (UTF-8 encoded).

        NOTE(review): the hash is unsalted and unkeyed, so low-entropy inputs
        (emails, phone numbers) can be recovered by brute force; consider
        ``generate_hmac`` where that matters. Left unchanged because stored
        digests may depend on it.
        """
        return hashlib.sha256(data.encode()).hexdigest()

    def generate_hmac(self, data: str) -> str:
        """Return the hex HMAC-SHA256 of ``data`` under the service key."""
        return hmac.new(
            self.secret_key,
            data.encode(),
            hashlib.sha256
        ).hexdigest()

    def verify_hmac(self, data: str, signature: str) -> bool:
        """Return True if ``signature`` is the HMAC of ``data``.

        Uses a constant-time comparison on bytes. ``hmac.compare_digest``
        raises TypeError for ``str`` arguments containing non-ASCII
        characters, which a caller-supplied signature may contain; encoding
        first turns that into a plain mismatch. A ``bytes`` signature is
        also accepted.
        """
        expected = self.generate_hmac(data).encode("ascii")
        if isinstance(signature, str):
            signature = signature.encode("utf-8")
        return hmac.compare_digest(expected, signature)

    def audit_configuration_change(self, user: str,
                                   old_config: Dict[str, Any],
                                   new_config: Dict[str, Any]) -> None:
        """Audit the differences between two flat configuration dicts.

        A key is recorded as changed when it was added, removed, or its value
        differs. Removed keys are recorded with ``"new": None``. The change
        set is always sent to the audit log. Keys starting with
        ``"security."`` are additionally reported as a
        ``security_config_change`` event.
        """
        changes: Dict[str, Dict[str, Any]] = {}
        for key, new_value in new_config.items():
            # Membership test so a key newly added with value None still counts.
            if key not in old_config or old_config[key] != new_value:
                changes[key] = {"old": old_config.get(key), "new": new_value}
        for key, old_value in old_config.items():
            if key not in new_config:
                changes[key] = {"old": old_value, "new": None}

        self.audit_logger.log_configuration_change(user, changes)

        security_changes = {
            k: v for k, v in changes.items() if k.startswith("security.")
        }
        if security_changes:
            self.security_logger.log_security_event(
                "security_config_change",
                user=user,
                changes=security_changes
            )

    def validate_prompt_security(self, prompt: str,
                                 context: SecurityContext) -> Dict[str, Any]:
        """Validate a prompt against length and rate-limit rules.

        Returns:
            ``{"allowed": bool, "warnings": [...], "blocked_reasons": [...]}``;
            the result is also logged via ``log_validation``.
        """
        results = {
            "allowed": True,
            "warnings": [],
            "blocked_reasons": []
        }
        # NOTE(review): len(prompt) counts characters while the limit is named
        # max_token_length; confirm the intended unit.
        if len(prompt) > self.config.security.max_token_length:
            results["blocked_reasons"].append("Prompt exceeds maximum length")
            results["allowed"] = False
        # Shares the per-user quota with validate_request, so calling both for
        # one request consumes two slots.
        if not self.rate_limiter.is_allowed(context.user_id):
            results["blocked_reasons"].append("Rate limit exceeded")
            results["allowed"] = False
        self.security_logger.log_validation(
            "prompt_security",
            {
                "user_id": context.user_id,
                "prompt_length": len(prompt),
                "results": results
            }
        )
        return results

    def check_permission(self, context: SecurityContext,
                         required_permission: str) -> bool:
        """Return True if ``required_permission`` is in ``context.permissions``."""
        return required_permission in context.permissions

    def sanitize_output(self, output: str) -> str:
        """Strip a small denylist of dangerous fragments from LLM output.

        Removes "sudo ", "rm -rf", "DROP TABLE" and "DELETE FROM" in any letter
        case. It repeats until nothing more matches, so nested input such as
        "DROP DROP TABLETABLE" cannot reassemble a fragment after one pass.
        This is a coarse filter, not a substitute for escaping or
        parameterizing output before it reaches a shell or database.
        """
        sanitized = output
        while True:
            cleaned = self._OUTPUT_BLOCKLIST.sub("", sanitized)
            if cleaned == sanitized:
                return cleaned
            sanitized = cleaned
class SecurityPolicy:
    """Registry of named policies, each a mapping of required context values.

    A context satisfies a policy when, for every ``key: value`` pair in the
    policy, the context contains ``key`` and its value equals ``value``.
    """

    def __init__(self):
        # policy name -> {context key: required value}
        self.policies = {}

    def add_policy(self, name: str, policy: Dict[str, Any]) -> None:
        """Register (or replace) the policy ``name``.

        A shallow copy is stored, so later mutation of the caller's dict
        cannot silently change an active policy.
        """
        self.policies[name] = dict(policy)

    def check_policy(self, name: str, context: Dict[str, Any]) -> bool:
        """Return True if ``context`` meets every requirement of policy ``name``.

        Unknown policy names fail closed (False). A required key absent from
        ``context`` is never satisfied, even when the required value is
        ``None``. A policy with no requirements accepts any context.
        """
        policy = self.policies.get(name)
        if policy is None:
            return False
        return all(
            key in context and context[key] == required
            for key, required in policy.items()
        )
class SecurityMetrics:
    """In-memory counters for security-relevant activity.

    Only the counters created in ``__init__`` exist; ``increment`` silently
    ignores any other name.
    """

    def __init__(self):
        self.metrics = {
            "requests": 0,
            "blocked_requests": 0,
            "warnings": 0,
            "rate_limits": 0,
            # SecurityMonitor.monitor_event increments this name; without the
            # entry those increments were ignored.
            "high_severity_events": 0
        }

    def increment(self, metric: str) -> None:
        """Add one to ``metric`` if it is a known counter; otherwise do nothing."""
        if metric in self.metrics:
            self.metrics[metric] += 1

    def get_metrics(self) -> Dict[str, int]:
        """Return a snapshot copy of all counters."""
        return self.metrics.copy()

    def reset_metrics(self) -> None:
        """Reset all counters to zero."""
        for key in self.metrics:
            self.metrics[key] = 0
class SecurityEvent:
    """One security occurrence: a type tag, a numeric severity and details.

    The creation time is captured as a naive UTC ``datetime``.
    """

    def __init__(self, event_type: str, severity: int,
                 details: Dict[str, Any]):
        self.event_type, self.severity, self.details = (
            event_type, severity, details
        )
        self.timestamp = datetime.utcnow()

    def to_dict(self) -> Dict[str, Any]:
        """Return the event as a plain dict with an ISO-8601 timestamp.

        ``details`` is included by reference, not copied.
        """
        return dict(
            event_type=self.event_type,
            severity=self.severity,
            details=self.details,
            timestamp=self.timestamp.isoformat(),
        )
class SecurityMonitor:
    """Security monitoring service.

    Records ``SecurityEvent`` objects. It logs a ``security_alert`` when at
    least ``alert_threshold`` of the last ``window`` events have severity
    ``>= high_severity``. The check runs after every event, so while enough
    high-severity events remain in the window, each new event re-raises the
    alert.

    Args:
        security_logger: Receives ``security_alert`` events.
        alert_threshold: High-severity events in the window that trigger an
            alert.
        high_severity: Minimum severity counted as high severity.
        window: Number of most recent events inspected and attached to alerts.
        max_events: Cap on retained events. The oldest are discarded beyond
            it so ``self.events`` cannot grow without bound. Never less than
            ``window``.

    Raises:
        ValueError: if ``window`` is less than 1.
    """

    def __init__(self, security_logger: SecurityLogger,
                 alert_threshold: int = 5,
                 high_severity: int = 8,
                 window: int = 10,
                 max_events: int = 1000):
        if window < 1:
            raise ValueError("window must be at least 1")
        self.security_logger = security_logger
        self.metrics = SecurityMetrics()
        self.events = []
        self.alert_threshold = alert_threshold  # Number of high-severity events before alerting
        self.high_severity = high_severity
        self.window = window
        self.max_events = max(max_events, window)

    def monitor_event(self, event: SecurityEvent) -> None:
        """Record ``event`` and trigger an alert if the threshold is reached."""
        self.events.append(event)
        # Trim the oldest events in place so self.events stays a bounded list.
        excess = len(self.events) - self.max_events
        if excess > 0:
            del self.events[:excess]

        if event.severity >= self.high_severity:
            self.metrics.increment("high_severity_events")

        high_severity_count = sum(
            1 for e in self.events[-self.window:]
            if e.severity >= self.high_severity
        )
        if high_severity_count >= self.alert_threshold:
            self.trigger_alert("High severity event threshold exceeded")

    def trigger_alert(self, reason: str) -> None:
        """Log a ``security_alert`` with ``reason`` and the most recent events."""
        self.security_logger.log_security_event(
            "security_alert",
            reason=reason,
            recent_events=[e.to_dict() for e in self.events[-self.window:]]
        )
if __name__ == "__main__":
    # Example usage. The relative imports at the top of this module resolve
    # only when it is run as part of its package (``python -m ...``).
    #
    # setup_logging is not imported at module level, so import it here.
    # NOTE(review): assumed to be provided by .logger alongside SecurityLogger
    # and AuditLogger and to return (security_logger, audit_logger); confirm.
    from .logger import setup_logging

    config = Config()
    security_logger, audit_logger = setup_logging()
    security_service = SecurityService(config, security_logger, audit_logger)

    # Create security context
    context = security_service.create_security_context(
        user_id="test_user",
        roles=["user"],
        permissions=["read", "write"]
    )

    # Create and verify token
    token = security_service.create_token(context)
    verified_context = security_service.verify_token(token)
    print(f"Token verification result: {verified_context}")

    # Validate request
    is_valid = security_service.validate_request(
        context,
        resource="api/data",
        action="read"
    )
    print(f"Request validation result: {is_valid}")