# Coild / auth_middleware.py
# loko-dev's picture
# Rename 'doc' route to 'docs' and update references in templates; remove unused doc.css file
# 474257f
import logging
import os
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import session, redirect, url_for, request, jsonify, make_response, current_app
from pocketbase import PocketBase
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Maximum number of requests a single IP may make inside the 60-second
# window that is_rate_limited() looks at.
RATE_LIMIT = 100 # requests per minute
# In-memory map: client IP -> list of time.time() stamps of its recent requests.
rate_limit_data = defaultdict(list)
# In-memory map: client IP -> number of token-refresh redirects issued for it
# by the before_request hook registered in init_auth().
refresh_attempts = defaultdict(int)
def is_rate_limited(ip):
    """Tell whether *ip* has exhausted its request budget for the last minute.

    Timestamps older than 60 seconds are discarded from ``rate_limit_data``
    first. If ``RATE_LIMIT`` or more recent requests remain, the call is
    reported as limited and is NOT recorded; otherwise the current time is
    recorded for *ip* and the call is allowed.

    Returns:
        True when the request should be rejected, False otherwise.
    """
    current = time.time()
    window_start = current - 60

    # Keep only the hits that still fall inside the sliding window.
    recent_hits = [stamp for stamp in rate_limit_data[ip] if stamp > window_start]
    rate_limit_data[ip] = recent_hits

    if len(recent_hits) < RATE_LIMIT:
        recent_hits.append(current)
        return False
    return True
def rate_limit_middleware():
    """Reject the current request with HTTP 429 when its IP is rate limited.

    Returns:
        ``None`` when the request may proceed, otherwise a
        ``(json_response, 429)`` tuple describing the rejection.
    """
    client_ip = request.remote_addr
    if not is_rate_limited(client_ip):
        return None
    return jsonify({'error': 'Rate limit exceeded'}), 429
def create_access_token(user_data, expires_delta=timedelta(minutes=60)):
    """Create a new HS256-signed access token.

    Args:
        user_data: Mapping read with ``.get``; ``id`` becomes the ``user_id``
            claim, ``email`` defaults to ``''`` and ``role`` to ``'user'``.
        expires_delta: Lifetime of the token, added to the current UTC time
            to form the ``exp`` claim. Defaults to 60 minutes.

    Returns:
        The encoded token as produced by ``jwt.encode``.

    The signing key is ``JWT_SECRET_KEY`` from the current app's config,
    falling back to the app's ``secret_key``; an application context is
    therefore required.
    """
    payload = {
        'user_id': user_data.get('id'),
        'email': user_data.get('email', ''),
        'role': user_data.get('role', 'user'),
        # Timezone-aware "now": datetime.utcnow() is deprecated since
        # Python 3.12 and yields a naive datetime.
        'exp': datetime.now(timezone.utc) + expires_delta,
    }
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')
def create_refresh_token(user_data, expires_delta=timedelta(days=30)):
    """Create a new HS256-signed refresh token.

    Args:
        user_data: Mapping read with ``.get``; only ``id`` is used and is
            stored as the ``user_id`` claim.
        expires_delta: Lifetime of the token, added to the current UTC time
            to form the ``exp`` claim. Defaults to 30 days.

    Returns:
        The encoded token as produced by ``jwt.encode``. The payload carries
        ``token_type == 'refresh'``.

    The signing key is ``JWT_SECRET_KEY`` from the current app's config,
    falling back to the app's ``secret_key``; an application context is
    therefore required.
    """
    payload = {
        'user_id': user_data.get('id'),
        'token_type': 'refresh',
        # Timezone-aware "now": datetime.utcnow() is deprecated since
        # Python 3.12 and yields a naive datetime.
        'exp': datetime.now(timezone.utc) + expires_delta,
    }
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')
def validate_token(token):
    """Decode a JWT and report the outcome as a ``(payload, error)`` pair.

    The token is verified with HS256 against ``JWT_SECRET_KEY`` from the
    current app's config, falling back to the app's ``secret_key``.

    Returns:
        ``(claims, None)`` when decoding succeeds,
        ``(None, "Token has expired")`` for an expired signature, and
        ``(None, "Invalid token")`` for any other invalid-token error.
    """
    try:
        signing_key = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
        claims = jwt.decode(token, signing_key, algorithms=['HS256'])
    except jwt.ExpiredSignatureError:
        # Must be handled before the generic invalid-token case.
        return None, "Token has expired"
    except jwt.InvalidTokenError:
        return None, "Invalid token"
    else:
        return claims, None
def init_auth(app):
    """Install session/JWT authentication on a Flask *app*.

    If the ``ENABLE_AUTH`` environment variable (default ``'true'``) is not
    ``'true'`` (case-insensitive), a message is logged and nothing else
    happens. Otherwise this function:

    * stores the JWT secret in ``app.config['JWT_SECRET_KEY']`` (environment
      variable ``JWT_SECRET_KEY``, falling back to ``app.secret_key``),
    * creates a PocketBase client for ``POCKETBASE_URL`` and exposes it as
      ``app.pb``,
    * registers a ``before_request`` hook that validates the access token
      kept in the session and populates ``session['user']``.

    Args:
        app: The Flask application to configure.
    """
    if os.getenv('ENABLE_AUTH', 'true').lower() != 'true':
        logger.info("Authentication disabled")
        return

    # Set JWT secret key
    app.config['JWT_SECRET_KEY'] = os.getenv('JWT_SECRET_KEY', app.secret_key)

    # Initialize PocketBase client
    pb = PocketBase(os.getenv('POCKETBASE_URL'))
    app.pb = pb

    # Endpoints that are reachable without a valid access token.
    public_endpoints = frozenset({
        'static', 'login', 'privacy', 'auth_callback',
        'token_refresh', 'favicon', 'docs',
    })
    max_refresh_attempts = 5      # per client IP within one window
    refresh_window_seconds = 60
    last_reset = time.monotonic()

    def reset_refresh_attempts_if_due():
        """Forget all refresh counters once the current window has elapsed."""
        nonlocal last_reset
        now = time.monotonic()
        if now - last_reset >= refresh_window_seconds:
            # Clearing in place also drops idle IPs, so the map cannot grow
            # without bound; the defaultdict yields 0 for missing keys.
            refresh_attempts.clear()
            last_reset = now

    def reject_session():
        """Drop the session and send the client to the login page."""
        session.clear()
        return redirect(url_for('login'))

    @app.before_request
    def before_request():
        """Authenticate the incoming request or divert it to login/refresh."""
        # Skip check if auth is disabled (the variable is re-read per request)
        if os.getenv('ENABLE_AUTH', 'true').lower() != 'true':
            return None

        # Skip auth for public endpoints
        if request.endpoint in public_endpoints:
            return None

        reset_refresh_attempts_if_due()

        # Check for access token in session
        access_token = session.get('access_token')
        if not access_token:
            if request.is_json:
                return jsonify({'error': 'Authentication required', 'code': 'AUTH_REQUIRED'}), 401
            return redirect(url_for('login'))

        # Validate the access token
        payload, error = validate_token(access_token)
        if error:
            # Expired token plus a refresh token: try a refresh round-trip.
            if error == "Token has expired" and session.get('refresh_token'):
                try:
                    # Prevent excessive refresh attempts
                    ip = request.remote_addr
                    if refresh_attempts[ip] >= max_refresh_attempts:
                        refresh_attempts[ip] = 0
                        return reject_session()
                    refresh_attempts[ip] += 1
                    return redirect(url_for('token_refresh', next=request.path))
                except Exception as e:
                    logger.error(f"Token refresh error: {str(e)}")
                    return reject_session()
            # Invalid token, or expired with no refresh token
            return reject_session()

        # A token without a user_id claim cannot identify anyone; treat it
        # like an invalid token instead of failing with a KeyError.
        if 'user_id' not in payload:
            logger.warning("Access token is missing the user_id claim")
            return reject_session()
        user_id = payload['user_id']

        # Restore PocketBase auth if a token is stored with the session user.
        # NOTE(review): app.pb is a single client object shared by every
        # request, so a restored token is visible to all of them - confirm
        # this is intended.
        pb_token = (session.get('user') or {}).get('token')
        if pb_token and not app.pb.auth_store.token:
            try:
                app.pb.auth_store.save(pb_token, None)
                logger.debug("Restored PocketBase authentication from session")
            except Exception as e:
                logger.warning(f"Failed to restore PocketBase auth: {e}")
                # Continue - PocketBase errors are handled in the routes

        # Set user data from the validated token
        if not session.get('user') or session['user'].get('id') != user_id:
            try:
                # The lookup result is not used; a failure is only logged and
                # the token claims are used either way.
                app.pb.collection('users').get_one(user_id)
            except Exception as e:
                logger.error(f"Error fetching user: {str(e)}")
            # Store minimal user data in session; defaults mirror the ones
            # used when the access token is created.
            session['user'] = {
                'id': user_id,
                'email': payload.get('email', ''),
                'role': payload.get('role', 'user'),
            }

        # Always ensure session is permanent
        session.permanent = True
        return None