ICH-Detection-Pipeline / security.py
Harshit Ghosh
Fix CSP to allow Hugging Face framing
19dbcd4
"""
Security utilities: headers, CSRF protection, input validation
"""
import os
import logging
from flask import request
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo
logger = logging.getLogger(__name__)
IST = ZoneInfo("Asia/Kolkata")
def now_ist() -> datetime:
return datetime.now(IST).replace(tzinfo=None)
def init_security(app):
"""Initialize security features for Flask app"""
@app.before_request
def set_security_headers():
"""Add security headers to all responses"""
pass # Headers are set in after_request
@app.after_request
def add_security_headers(response):
"""Add security headers to all responses"""
# Prevent clickjacking attacks
# X-Frame-Options is removed because it prevents Hugging Face from embedding the app.
# We rely on CSP frame-ancestors instead.
# Prevent MIME type sniffing
response.headers['X-Content-Type-Options'] = 'nosniff'
# Enable XSS protection in older browsers
response.headers['X-XSS-Protection'] = '1; mode=block'
# Content Security Policy - restrictive but functional
csp = (
"default-src 'self'; "
"script-src 'self' 'unsafe-inline'; " # Minimal unsafe-inline for compatibility
"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; "
"font-src 'self' https://fonts.gstatic.com; "
"img-src 'self' data: https://res.cloudinary.com; "
"connect-src 'self'; "
"frame-ancestors 'self' https://huggingface.co https://*.hf.space; "
"base-uri 'self'; "
"form-action 'self'"
)
response.headers['Content-Security-Policy'] = csp
# Referrer policy
response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'
# Feature policy / Permissions policy
response.headers['Permissions-Policy'] = (
'geolocation=(), microphone=(), camera=(), usb=(), payment=()'
)
# HSTS (HTTP Strict-Transport-Security) - only on HTTPS
if request.is_secure or os.environ.get('FLASK_ENV') == 'production':
response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
return response
# Session security
app.config.update(
SESSION_COOKIE_SECURE=os.environ.get('FLASK_ENV') == 'production',
SESSION_COOKIE_HTTPONLY=True,
SESSION_COOKIE_SAMESITE='Lax',
PERMANENT_SESSION_LIFETIME=timedelta(days=30),
)
logger.info("Security headers and features initialized")
def sanitize_filename(filename: str, max_length: int = 255) -> str:
"""
Sanitize filename to prevent directory traversal and other attacks.
Args:
filename: The filename to sanitize
max_length: Maximum length for the sanitized filename
Returns:
Safe filename
"""
import re
# Remove any path components
filename = os.path.basename(filename)
# Remove null bytes
filename = filename.replace('\0', '')
# Allow only safe characters (alphanumeric, dash, underscore, dot)
filename = re.sub(r'[^\w\-\.]', '_', filename)
# Remove leading/trailing dots and spaces
filename = filename.strip('. ')
# Prevent empty filename
if not filename:
filename = 'file'
# Limit length
if len(filename) > max_length:
# Preserve extension
name, ext = os.path.splitext(filename)
filename = name[:max_length - len(ext)] + ext
return filename
def validate_file_extension(filename: str, allowed_extensions: list) -> bool:
"""
Validate that file has an allowed extension.
Args:
filename: The filename to validate
allowed_extensions: List of allowed extensions (without dots)
Returns:
True if extension is allowed, False otherwise
"""
if not filename or '.' not in filename:
return False
ext = filename.rsplit('.', 1)[-1].lower()
return ext in [e.lower() for e in allowed_extensions]
def mask_sensitive_data(data: dict, fields_to_mask: list) -> dict:
"""
Mask sensitive fields in a dictionary before logging or sending to client.
Args:
data: Dictionary containing data to mask
fields_to_mask: List of field names to mask
Returns:
Dictionary with masked fields
"""
import copy
masked = copy.deepcopy(data)
for field in fields_to_mask:
if field in masked:
value = str(masked[field])
if len(value) > 4:
masked[field] = value[:2] + '*' * (len(value) - 4) + value[-2:]
else:
masked[field] = '*' * len(value)
return masked
def get_client_info() -> dict:
"""Extract client information from request for logging"""
return {
'ip_address': request.remote_addr,
'user_agent': request.headers.get('User-Agent', 'Unknown'),
'endpoint': request.endpoint,
'method': request.method,
'timestamp': now_ist().isoformat()
}
class RateLimiter:
"""Simple in-memory rate limiter for protecting against abuse"""
def __init__(self, max_requests: int = 100, window_seconds: int = 60):
self.max_requests = max_requests
self.window_seconds = window_seconds
self.requests = {} # {key: [(timestamp, count), ...]}
def is_rate_limited(self, key: str) -> bool:
"""Check if a key has exceeded rate limit"""
now = now_ist()
window_start = now - timedelta(seconds=self.window_seconds)
# Clean old entries
if key in self.requests:
self.requests[key] = [
(ts, count) for ts, count in self.requests[key]
if ts > window_start
]
# Count requests in window
total_requests = sum(count for _, count in self.requests.get(key, []))
return total_requests >= self.max_requests
def record_request(self, key: str):
"""Record a request for rate limiting"""
now = now_ist()
if key not in self.requests:
self.requests[key] = []
# Add or increment the count for this second
if self.requests[key] and self.requests[key][-1][0] == now:
ts, count = self.requests[key][-1]
self.requests[key][-1] = (ts, count + 1)
else:
self.requests[key].append((now, 1))
# Global rate limiter instances
login_rate_limiter = RateLimiter(max_requests=5, window_seconds=300) # 5 attempts in 5 minutes
upload_rate_limiter = RateLimiter(max_requests=20, window_seconds=3600) # 20 uploads per hour
def check_login_rate_limit(identifier: str) -> tuple[bool, str]:
"""
Check if login attempt should be rate limited.
Returns (is_limited, message)
"""
if login_rate_limiter.is_rate_limited(identifier):
return True, "Too many login attempts. Please try again later."
login_rate_limiter.record_request(identifier)
return False, ""
def check_upload_rate_limit(user_id: int) -> tuple[bool, str]:
"""
Check if upload should be rate limited.
Returns (is_limited, message)
"""
key = f"upload_{user_id}"
if upload_rate_limiter.is_rate_limited(key):
return True, "Upload rate limit exceeded. Maximum 20 uploads per hour."
upload_rate_limiter.record_request(key)
return False, ""