File size: 6,361 Bytes
5a8e751 b3985a0 5a8e751 6e1c7b6 fc33d7c 7bf4296 b3985a0 5a8e751 fc33d7c b3985a0 7bf4296 fc33d7c 7bf4296 b3985a0 7bf4296 b3985a0 7bf4296 b3985a0 7bf4296 b3985a0 7bf4296 b3985a0 7bf4296 b3985a0 7bf4296 b3985a0 7bf4296 fc33d7c 02fa519 6e1c7b6 b3985a0 5964593 b3985a0 5964593 b3985a0 5964593 7bf4296 b3985a0 474257f b3985a0 51a00b8 b3985a0 7bf4296 b3985a0 c4c6b7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import logging
import os
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import session, redirect, url_for, request, jsonify, make_response, current_app
from pocketbase import PocketBase
logger = logging.getLogger(__name__)

# Maximum number of requests one client IP may make per rate-limit window.
RATE_LIMIT = 100  # requests per minute
# Length of the sliding rate-limit window, in seconds.
RATE_LIMIT_WINDOW = 60
# Maps client IP -> time.time() stamps of its accepted requests inside the window.
rate_limit_data = defaultdict(list)
# time.time() of the last sweep that evicted idle IPs from rate_limit_data.
_last_rate_limit_sweep = 0.0

# Track token refresh operations: client IP -> refresh redirects issued
# (periodically reset by the auth hook).
refresh_attempts = defaultdict(int)


def _evict_idle_ips(cutoff):
    """Remove IPs from rate_limit_data whose newest request is at or before *cutoff*."""
    # Iterate over a snapshot so entries can be popped safely.
    for ip, stamps in list(rate_limit_data.items()):
        # Stamps are appended in arrival order, so the last one is the newest.
        if not stamps or stamps[-1] <= cutoff:
            rate_limit_data.pop(ip, None)


def is_rate_limited(ip, limit=None):
    """Record a request from *ip* and report whether it exceeds the limit.

    Uses a sliding window of RATE_LIMIT_WINDOW seconds. Rejected requests are
    not recorded, so a blocked client regains access as its older requests
    age out of the window.

    Args:
        ip: Client identifier (normally ``request.remote_addr``).
        limit: Maximum requests allowed per window; defaults to RATE_LIMIT.

    Returns:
        True if the request should be rejected, False if it was accepted
        (and recorded).
    """
    global _last_rate_limit_sweep
    if limit is None:
        limit = RATE_LIMIT
    now = time.time()
    cutoff = now - RATE_LIMIT_WINDOW

    # Periodically evict IPs that went quiet; otherwise every address ever
    # seen keeps an entry forever and the dict grows without bound.
    if now - _last_rate_limit_sweep >= RATE_LIMIT_WINDOW:
        _evict_idle_ips(cutoff)
        _last_rate_limit_sweep = now

    # Drop stamps that have left the window, then check the limit.
    recent = [t for t in rate_limit_data[ip] if t > cutoff]
    rate_limit_data[ip] = recent
    if len(recent) >= limit:
        return True
    recent.append(now)
    return False
def rate_limit_middleware():
    """Reject the current request with HTTP 429 when its client IP is over the limit.

    NOTE(review): presumably meant to be registered as a Flask before_request
    hook (registration is not visible here). Behind a reverse proxy,
    ``remote_addr`` is the proxy's address, so all clients would share one
    budget — confirm the deployment.

    Returns:
        ``(json_response, 429)`` when limited, otherwise ``None``.
    """
    client_ip = request.remote_addr
    limited = is_rate_limited(client_ip)
    if not limited:
        return None
    return jsonify({'error': 'Rate limit exceeded'}), 429
def create_access_token(user_data, expires_delta=timedelta(minutes=60)):
    """Create a new HS256-signed access token.

    Args:
        user_data: Dict-like user record; ``id`` is read, ``email`` and
            ``role`` default to ``''`` and ``'user'`` when absent.
        expires_delta: Token lifetime (default: 60 minutes).

    Returns:
        The encoded JWT, as produced by ``jwt.encode``.
    """
    payload = {
        'user_id': user_data.get('id'),
        'email': user_data.get('email', ''),
        'role': user_data.get('role', 'user'),
        # Explicit type so an access token can never be mistaken for a refresh token.
        'token_type': 'access',
        # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
        'exp': datetime.now(timezone.utc) + expires_delta,
    }
    # Prefer a dedicated JWT secret, falling back to Flask's secret key.
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')
def create_refresh_token(user_data, expires_delta=timedelta(days=30)):
    """Create a new HS256-signed refresh token.

    The payload only carries the user id and ``token_type='refresh'``; no
    email or role is included.

    Args:
        user_data: Dict-like user record; only ``id`` is read.
        expires_delta: Token lifetime (default: 30 days).

    Returns:
        The encoded JWT, as produced by ``jwt.encode``.
    """
    payload = {
        'user_id': user_data.get('id'),
        'token_type': 'refresh',
        # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
        'exp': datetime.now(timezone.utc) + expires_delta,
    }
    # Prefer a dedicated JWT secret, falling back to Flask's secret key.
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')
def validate_token(token, expected_type=None):
    """Decode and verify an HS256 JWT signed with the app's JWT secret.

    Args:
        token: The encoded JWT.
        expected_type: Optional token kind to require (``'access'`` or
            ``'refresh'``). Tokens without a ``token_type`` claim count as
            access tokens. ``None`` (the default) skips the check.

    Returns:
        ``(payload, None)`` on success, or ``(None, message)`` where message is
        ``"Token has expired"`` or ``"Invalid token"``.
    """
    # Prefer a dedicated JWT secret, falling back to Flask's secret key.
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    try:
        payload = jwt.decode(token, secret, algorithms=['HS256'])
    # ExpiredSignatureError subclasses InvalidTokenError, so it must come first.
    except jwt.ExpiredSignatureError:
        return None, "Token has expired"
    except jwt.InvalidTokenError:
        return None, "Invalid token"
    # Without this check a long-lived refresh token would validate as an access token.
    if expected_type is not None and payload.get('token_type', 'access') != expected_type:
        return None, "Invalid token"
    return payload, None
# Endpoints that must stay reachable without an authenticated session.
_PUBLIC_ENDPOINTS = frozenset({
    'static', 'login', 'privacy', 'auth_callback', 'token_refresh', 'favicon', 'docs',
})
# Refresh redirects allowed per client IP within one reset interval.
_MAX_REFRESH_ATTEMPTS = 5
# Seconds between resets of the per-IP refresh counters.
_REFRESH_RESET_INTERVAL = 60


def _auth_enabled():
    """Return True unless the ENABLE_AUTH env var is set to something other than 'true'."""
    return os.getenv('ENABLE_AUTH', 'true').lower() == 'true'


def _handle_token_error(error):
    """Build the response for a request whose access token failed validation.

    An expired token with a refresh token in the session is redirected to the
    ``token_refresh`` endpoint (at most _MAX_REFRESH_ATTEMPTS times per IP per
    reset interval). Every other case clears the session and redirects to
    ``login``.
    """
    if error == "Token has expired" and session.get('refresh_token'):
        try:
            ip = request.remote_addr
            # Prevent excessive refresh attempts from one IP.
            if refresh_attempts[ip] >= _MAX_REFRESH_ATTEMPTS:
                session.clear()
                refresh_attempts[ip] = 0
                return redirect(url_for('login'))
            refresh_attempts[ip] += 1
            return redirect(url_for('token_refresh', next=request.path))
        except Exception as e:
            logger.error(f"Token refresh error: {str(e)}")
            session.clear()
            return redirect(url_for('login'))
    # Invalid token, or expired token with no refresh token.
    session.clear()
    return redirect(url_for('login'))


def _restore_pocketbase_auth(pb):
    """Load the PocketBase token stored in the session into *pb*'s auth store."""
    pb_token = (session.get('user') or {}).get('token')
    # *pb* is one client shared by every request, so only filling an empty
    # store would keep the first user's token and reuse it for later users.
    # Restore whenever the session's token differs from the stored one.
    # NOTE(review): a process-wide client is still unsafe under concurrent
    # requests; a per-request client (e.g. on flask.g) would avoid that.
    if pb_token and pb.auth_store.token != pb_token:
        try:
            pb.auth_store.save(pb_token, None)
            logger.debug("Restored PocketBase authentication from session")
        except Exception as e:
            logger.warning(f"Failed to restore PocketBase auth: {e}")
            # Continue - PocketBase errors are handled in the routes.


def _sync_session_user(pb, payload):
    """Store minimal user data from *payload* in the session if missing or for another user."""
    user_id = payload['user_id']
    current = session.get('user')
    if current and current.get('id') == user_id:
        return
    try:
        # NOTE(review): the fetched record is not used; this lookup only logs
        # when PocketBase cannot return the user. Use the record or drop the call.
        pb.collection('users').get_one(user_id)
    except Exception as e:
        logger.error(f"Error fetching user: {str(e)}")
    # The same token-derived data is stored whether or not PocketBase answered.
    # .get() defaults match create_access_token's, so a sparse payload cannot
    # raise KeyError here.
    session['user'] = {
        'id': user_id,
        'email': payload.get('email', ''),
        'role': payload.get('role', 'user'),
    }


def init_auth(app):
    """Configure JWT settings, a PocketBase client and the auth guard on *app*.

    When ENABLE_AUTH is not 'true', only logs and returns. Otherwise sets
    ``app.config['JWT_SECRET_KEY']``, attaches a PocketBase client as
    ``app.pb`` and registers a before_request hook. The hook redirects (or
    returns 401 for JSON requests) when the session has no valid access token.
    """
    if not _auth_enabled():
        logger.info("Authentication disabled")
        return
    # Use a dedicated JWT secret if provided, otherwise Flask's secret key.
    app.config['JWT_SECRET_KEY'] = os.getenv('JWT_SECRET_KEY', app.secret_key)
    # One PocketBase client, shared by all requests of this app.
    app.pb = PocketBase(os.getenv('POCKETBASE_URL'))
    # When refresh_attempts was last reset; updated by before_request.
    last_refresh_reset = time.time()

    @app.before_request
    def before_request():
        nonlocal last_refresh_reset
        # ENABLE_AUTH is re-read on every request.
        if not _auth_enabled():
            return None

        # Reset refresh counters based on elapsed time so the reset happens
        # whenever an interval has passed, regardless of which second requests
        # arrive in. Running it before any early return means all traffic
        # advances it, not only fully authenticated requests.
        now = time.time()
        if now - last_refresh_reset >= _REFRESH_RESET_INTERVAL:
            refresh_attempts.clear()
            last_refresh_reset = now

        if request.endpoint in _PUBLIC_ENDPOINTS:
            return None

        access_token = session.get('access_token')
        if not access_token:
            if request.is_json:
                return jsonify({'error': 'Authentication required', 'code': 'AUTH_REQUIRED'}), 401
            return redirect(url_for('login'))

        payload, error = validate_token(access_token)
        if not error and payload.get('token_type') == 'refresh':
            # A refresh token must not authorize ordinary requests.
            error = "Invalid token"
        if error:
            return _handle_token_error(error)

        _restore_pocketbase_auth(app.pb)
        _sync_session_user(app.pb, payload)
        # Always ensure session is permanent.
        session.permanent = True
        return None
|