|
import logging
import os
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import session, redirect, url_for, request, jsonify, make_response, current_app
from pocketbase import PocketBase
|
|
|
logger = logging.getLogger(__name__)

# Maximum number of accepted requests per client IP inside the sliding
# window checked by is_rate_limited.
RATE_LIMIT = 100

# client IP -> list of request times (seconds) for accepted requests.
# Entries are pruned lazily by is_rate_limited; keys themselves are never
# evicted in this module, so the dict grows with distinct client addresses.
rate_limit_data = defaultdict(list)

# client IP -> number of redirects to the token_refresh endpoint issued by the
# auth before_request hook for an expired access token.
refresh_attempts = defaultdict(int)
|
|
|
def is_rate_limited(ip, *, limit=None, window=60.0, store=None, now=None):
    """Record a request from *ip* and report whether it exceeds the rate limit.

    Sliding-window limiter: a request is rejected when *ip* already has
    ``limit`` accepted requests within the last ``window`` seconds. Rejected
    requests are not recorded, so hammering while blocked does not extend
    the block.

    Args:
        ip: Client key (callers pass ``request.remote_addr``).
        limit: Max accepted requests per window; defaults to ``RATE_LIMIT``.
        window: Window length in seconds (default 60).
        store: Mapping of key -> list of accepted request times; defaults to
            the module-level ``rate_limit_data``.
        now: Current time on the same clock as the stored times; defaults to
            ``time.monotonic()``.

    Returns:
        True if the request should be rejected, False if it was accepted
        (and recorded).
    """
    if limit is None:
        limit = RATE_LIMIT
    if store is None:
        store = rate_limit_data
    if now is None:
        # Monotonic clock: wall-clock adjustments (NTP, manual changes) must
        # not stretch or collapse the window.
        now = time.monotonic()

    cutoff = now - window
    # .get() rather than [] so plain dicts work as a store, too.
    recent = [t for t in store.get(ip, ()) if t > cutoff]

    if len(recent) >= limit:
        store[ip] = recent
        return True

    recent.append(now)
    store[ip] = recent
    return False
|
|
|
def rate_limit_middleware():
    """Reject the current request when its client address is over the limit.

    Shaped like a Flask before-request hook: returns a ``(json, 429)``
    response tuple for an over-limit client, otherwise ``None`` so normal
    request handling continues.
    """
    client_addr = request.remote_addr
    if not is_rate_limited(client_addr):
        return None
    return jsonify({'error': 'Rate limit exceeded'}), 429
|
|
|
def create_access_token(user_data, expires_delta=timedelta(minutes=60)):
    """Create a signed HS256 access token for *user_data*.

    Args:
        user_data: Mapping providing ``'id'``; ``'email'`` defaults to ``''``
            and ``'role'`` to ``'user'`` when absent.
        expires_delta: Token lifetime (default 60 minutes).

    Returns:
        The encoded JWT as returned by ``jwt.encode``. It is signed with
        ``JWT_SECRET_KEY`` from the app config, falling back to
        ``app.secret_key``.
    """
    payload = {
        'user_id': user_data.get('id'),
        'email': user_data.get('email', ''),
        'role': user_data.get('role', 'user'),
        # Timezone-aware UTC; naive datetime.utcnow() is deprecated since 3.12.
        'exp': datetime.now(timezone.utc) + expires_delta,
    }
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')
|
|
|
def create_refresh_token(user_data, expires_delta=timedelta(days=30)):
    """Create a signed HS256 refresh token for *user_data*.

    The payload holds only the user id, ``token_type='refresh'`` and the
    expiry (no email/role claims).

    Args:
        user_data: Mapping providing ``'id'``.
        expires_delta: Token lifetime (default 30 days).

    Returns:
        The encoded JWT as returned by ``jwt.encode``. It is signed with
        ``JWT_SECRET_KEY`` from the app config, falling back to
        ``app.secret_key``.
    """
    payload = {
        'user_id': user_data.get('id'),
        'token_type': 'refresh',
        # Timezone-aware UTC; naive datetime.utcnow() is deprecated since 3.12.
        'exp': datetime.now(timezone.utc) + expires_delta,
    }
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')
|
|
|
def validate_token(token):
    """Decode and verify an HS256-signed JWT.

    The key is ``JWT_SECRET_KEY`` from the app config, falling back to
    ``app.secret_key``.

    Args:
        token: The encoded JWT.

    Returns:
        A ``(payload, error)`` pair. On success this is ``(claims, None)``.
        On failure it is ``(None, message)``, where *message* is
        ``"Token has expired"`` or ``"Invalid token"``.
    """
    signing_key = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    try:
        claims = jwt.decode(token, signing_key, algorithms=['HS256'])
    except jwt.ExpiredSignatureError:
        # Must precede InvalidTokenError, of which it is a subclass.
        return None, "Token has expired"
    except jwt.InvalidTokenError:
        return None, "Invalid token"
    return claims, None
|
|
|
# Endpoints reachable without an access token.
_PUBLIC_ENDPOINTS = frozenset({
    'static', 'login', 'privacy', 'auth_callback', 'token_refresh', 'favicon', 'docs',
})

# Once a client IP has been redirected to token_refresh more than this many
# times within one reset window, its next expired-token request clears the
# session and goes to login instead (guards against refresh redirect loops).
_MAX_REFRESH_ATTEMPTS = 5

# How often (seconds) the refresh_attempts counters are wiped.
_REFRESH_ATTEMPTS_RESET_SECONDS = 60


def _auth_enabled():
    """Return True unless the ENABLE_AUTH env var is set to something other than 'true'."""
    return os.getenv('ENABLE_AUTH', 'true').lower() == 'true'


def _handle_token_error(error):
    """Build the response for a session whose access token failed validation."""
    if error != "Token has expired" or not session.get('refresh_token'):
        session.clear()
        return redirect(url_for('login'))

    ip = request.remote_addr
    if refresh_attempts[ip] > _MAX_REFRESH_ATTEMPTS:
        session.clear()
        refresh_attempts[ip] = 0
        return redirect(url_for('login'))

    refresh_attempts[ip] += 1
    try:
        return redirect(url_for('token_refresh', next=request.path))
    except Exception as e:
        logger.exception("Token refresh error: %s", e)
        session.clear()
        return redirect(url_for('login'))


def _restore_pocketbase_auth(pb):
    """Load the session's PocketBase token (session['user']['token']) into *pb*'s auth store."""
    pb_token = (session.get('user') or {}).get('token')
    # The client is shared by every request. Compare against the stored token
    # (instead of only filling an empty store) so a token left behind by
    # another session is replaced with this session's token.
    # NOTE(review): sessions without a token leave the previous token in
    # place, and concurrent requests can interleave on this shared store; a
    # per-request client would be needed to fully isolate users.
    if not pb_token or pb.auth_store.token == pb_token:
        return
    try:
        pb.auth_store.save(pb_token, None)
        logger.debug("Restored PocketBase authentication from session")
    except Exception as e:
        logger.warning(f"Failed to restore PocketBase auth: {e}")


def _sync_session_user(pb, payload):
    """Rebuild session['user'] from token claims when missing or for a different user id."""
    session_user = session.get('user')
    if session_user and session_user.get('id') == payload['user_id']:
        return
    try:
        # NOTE(review): the fetched record is not used and a failure is only
        # logged, so this is at most a best-effort existence check.
        pb.collection('users').get_one(payload['user_id'])
    except Exception as e:
        logger.error("Error fetching user: %s", e)
    session['user'] = {
        'id': payload['user_id'],
        'email': payload['email'],
        'role': payload['role'],
    }


def init_auth(app):
    """Configure JWT/PocketBase authentication and register the auth hook.

    If the ``ENABLE_AUTH`` env var is anything other than ``'true'``
    (case-insensitive), this only logs and returns. Otherwise it:

    * sets ``app.config['JWT_SECRET_KEY']`` from the ``JWT_SECRET_KEY`` env
      var, falling back to ``app.secret_key``;
    * attaches a PocketBase client for ``POCKETBASE_URL`` as ``app.pb``;
    * registers a ``before_request`` hook.

    On each request, the hook does the following:

    * Public endpoints are skipped.
    * A session access token is required. A missing token yields a 401 JSON
      response for JSON requests, or a redirect to ``login`` otherwise.
    * Expired tokens go to ``token_refresh`` when a refresh token exists,
      rate-limited per client IP. Other failures clear the session and
      redirect to ``login``.
    * The session's PocketBase token is loaded into ``app.pb``, and
      ``session['user']`` is kept in line with the token claims.

    Args:
        app: The Flask application to configure.
    """
    if not _auth_enabled():
        logger.info("Authentication disabled")
        return

    app.config['JWT_SECRET_KEY'] = os.getenv('JWT_SECRET_KEY', app.secret_key)
    if not app.config['JWT_SECRET_KEY']:
        logger.warning("No JWT secret configured: JWT_SECRET_KEY and app.secret_key are both empty")

    app.pb = PocketBase(os.getenv('POCKETBASE_URL'))

    # Monotonic time of the last refresh_attempts wipe.
    last_reset = time.monotonic()

    @app.before_request
    def before_request():
        nonlocal last_reset

        # Re-checked per request, mirroring the check at init time.
        if not _auth_enabled():
            return None

        # Wipe refresh counters once per window. This runs on every hook call
        # so the reset cannot be missed. clear() also drops stale IP keys and
        # is a single operation, so no iteration over a dict that other
        # requests may be growing.
        now = time.monotonic()
        if now - last_reset >= _REFRESH_ATTEMPTS_RESET_SECONDS:
            refresh_attempts.clear()
            last_reset = now

        if request.endpoint in _PUBLIC_ENDPOINTS:
            return None

        access_token = session.get('access_token')
        if not access_token:
            if request.is_json:
                return jsonify({'error': 'Authentication required', 'code': 'AUTH_REQUIRED'}), 401
            return redirect(url_for('login'))

        payload, error = validate_token(access_token)
        if error:
            return _handle_token_error(error)

        _restore_pocketbase_auth(app.pb)
        _sync_session_user(app.pb, payload)
        session.permanent = True
        return None
|
|
|
|