File size: 6,361 Bytes
5a8e751
b3985a0
 
5a8e751
6e1c7b6
fc33d7c
7bf4296
 
b3985a0
5a8e751
 
 
fc33d7c
 
 
b3985a0
 
7bf4296
fc33d7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bf4296
b3985a0
7bf4296
b3985a0
7bf4296
 
 
 
b3985a0
7bf4296
 
 
b3985a0
7bf4296
b3985a0
7bf4296
 
 
b3985a0
7bf4296
 
 
 
 
b3985a0
7bf4296
 
 
 
 
 
fc33d7c
02fa519
6e1c7b6
 
 
b3985a0
 
 
5964593
b3985a0
 
 
 
5964593
b3985a0
 
 
5964593
7bf4296
b3985a0
474257f
b3985a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51a00b8
 
 
 
 
 
 
 
 
 
 
b3985a0
 
 
 
 
7bf4296
b3985a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c4c6b7b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import logging
import os
import time
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from functools import wraps

import jwt
from flask import session, redirect, url_for, request, jsonify, make_response, current_app
from pocketbase import PocketBase

# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)

# Maximum number of requests a single client IP may make inside the
# 60-second sliding window enforced by is_rate_limited().
RATE_LIMIT = 100  # requests per minute
# Maps client IP -> list of time.time() timestamps of its recent requests.
rate_limit_data = defaultdict(list)

# Maps client IP -> number of token-refresh redirects issued for it by the
# before_request hook registered in init_auth().
refresh_attempts = defaultdict(int)

def is_rate_limited(ip):
    """Report whether *ip* has used up its request budget for the last minute.

    Timestamps older than 60 seconds are discarded first. When the client is
    still under ``RATE_LIMIT`` the current request is recorded and ``False``
    is returned; otherwise nothing is recorded and ``True`` is returned.
    """
    current = time.time()
    window_start = current - 60

    # Keep only the hits that still fall inside the sliding window.
    recent_hits = [stamp for stamp in rate_limit_data[ip] if stamp > window_start]
    rate_limit_data[ip] = recent_hits

    if len(recent_hits) < RATE_LIMIT:
        # Under budget: count this request against the window.
        recent_hits.append(current)
        return False

    return True

def rate_limit_middleware():
    """Return a 429 JSON response when the calling client is rate limited.

    Returns ``None`` when the request is within the limit, so the caller can
    let request handling continue.
    """
    client_ip = request.remote_addr
    if not is_rate_limited(client_ip):
        return None

    body = jsonify({'error': 'Rate limit exceeded'})
    return body, 429

def create_access_token(user_data, expires_delta=timedelta(minutes=60)):
    """Create a signed HS256 access token for a user.

    Must be called where Flask's ``current_app`` is available, because the
    signing secret is read from the application.

    Args:
        user_data: Mapping describing the user. ``id`` is stored as the
            ``user_id`` claim; a missing ``email`` becomes ``''`` and a
            missing ``role`` becomes ``'user'``.
        expires_delta: Token lifetime, added to the current UTC time to form
            the ``exp`` claim.

    Returns:
        The encoded JWT produced by ``jwt.encode``.
    """
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # value; an aware UTC timestamp encodes to the same ``exp`` epoch.
    expires_at = datetime.now(timezone.utc) + expires_delta
    payload = {
        'user_id': user_data.get('id'),
        'email': user_data.get('email', ''),
        'role': user_data.get('role', 'user'),
        'exp': expires_at,
    }
    # Prefer the dedicated JWT secret; fall back to the Flask secret key.
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')

def create_refresh_token(user_data, expires_delta=timedelta(days=30)):
    """Create a signed HS256 refresh token for a user.

    Must be called where Flask's ``current_app`` is available, because the
    signing secret is read from the application.

    Args:
        user_data: Mapping describing the user; only ``id`` is read and
            stored as the ``user_id`` claim.
        expires_delta: Token lifetime, added to the current UTC time to form
            the ``exp`` claim.

    Returns:
        The encoded JWT produced by ``jwt.encode``. Its ``token_type`` claim
        is ``'refresh'``.
    """
    # datetime.utcnow() is deprecated since Python 3.12 and returns a naive
    # value; an aware UTC timestamp encodes to the same ``exp`` epoch.
    expires_at = datetime.now(timezone.utc) + expires_delta
    payload = {
        'user_id': user_data.get('id'),
        'token_type': 'refresh',
        'exp': expires_at,
    }
    # Prefer the dedicated JWT secret; fall back to the Flask secret key.
    secret = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key
    return jwt.encode(payload, secret, algorithm='HS256')

def validate_token(token):
    """Decode and verify an HS256 JWT.

    Returns a ``(payload, error)`` pair: on success ``error`` is ``None``;
    on failure ``payload`` is ``None`` and ``error`` is a short message
    saying whether the token expired or is otherwise invalid.
    """
    # Same secret resolution used when the tokens are issued.
    signing_key = current_app.config.get('JWT_SECRET_KEY') or current_app.secret_key

    try:
        claims = jwt.decode(token, signing_key, algorithms=['HS256'])
    except jwt.ExpiredSignatureError:
        # Checked first: callers rely on this exact message to trigger a refresh.
        return None, "Token has expired"
    except jwt.InvalidTokenError:
        return None, "Invalid token"
    else:
        return claims, None

def init_auth(app):
    """Wire JWT session authentication into a Flask application.

    When the ``ENABLE_AUTH`` environment variable is anything other than
    ``true`` (case-insensitive, default ``true``) this only logs a message
    and returns. Otherwise it:

    * stores the signing secret in ``app.config['JWT_SECRET_KEY']`` (the
      ``JWT_SECRET_KEY`` environment variable, falling back to
      ``app.secret_key``);
    * creates a PocketBase client for ``POCKETBASE_URL`` and exposes it as
      ``app.pb``;
    * registers a ``before_request`` hook that validates the access token
      held in the session and redirects to login or token refresh as needed.

    Args:
        app: The Flask application to configure.
    """
    if os.getenv('ENABLE_AUTH', 'true').lower() != 'true':
        logger.info("Authentication disabled")
        return

    # Set JWT secret key
    app.config['JWT_SECRET_KEY'] = os.getenv('JWT_SECRET_KEY', app.secret_key)

    # Initialize PocketBase client
    pb = PocketBase(os.getenv('POCKETBASE_URL'))
    app.pb = pb

    # Endpoints reachable without a valid access token.
    public_endpoints = frozenset({
        'static', 'login', 'privacy', 'auth_callback',
        'token_refresh', 'favicon', 'docs',
    })
    max_refresh_attempts = 5       # per client IP, per window
    refresh_window_seconds = 60
    # Start of the current refresh-attempt window. A monotonic clock is used
    # so the reset does not depend on wall-clock time or on a request
    # happening to land in one particular second of the minute.
    refresh_window = {'started': time.monotonic()}

    def reset_refresh_attempts_if_due():
        """Zero all per-IP refresh counters once the window has elapsed."""
        now = time.monotonic()
        if now - refresh_window['started'] >= refresh_window_seconds:
            # Cleared in place: other code may hold a reference to the dict,
            # and dropping keys keeps it from growing without bound.
            refresh_attempts.clear()
            refresh_window['started'] = now

    def force_login():
        """Drop the session and send the client to the login page."""
        session.clear()
        return redirect(url_for('login'))

    @app.before_request
    def before_request():
        """Authenticate the current request or short-circuit with a response."""
        # Skip check if auth is disabled
        if os.getenv('ENABLE_AUTH', 'true').lower() != 'true':
            return None

        # Skip auth for public paths
        if request.endpoint in public_endpoints:
            return None

        reset_refresh_attempts_if_due()

        access_token = session.get('access_token')
        if not access_token:
            # No token at all: JSON clients get a 401, browsers get the login page.
            if request.is_json:
                return jsonify({'error': 'Authentication required', 'code': 'AUTH_REQUIRED'}), 401
            return redirect(url_for('login'))

        payload, error = validate_token(access_token)

        if error:
            # NOTE(review): unlike the missing-token case above, JSON clients
            # are redirected here rather than given a 401 - confirm intended.
            can_refresh = error == "Token has expired" and session.get('refresh_token')
            if not can_refresh:
                # Invalid token, or expired with no refresh token
                return force_login()

            # Prevent excessive refresh attempts from one client IP.
            ip = request.remote_addr
            if refresh_attempts[ip] >= max_refresh_attempts:
                refresh_attempts[ip] = 0
                return force_login()
            refresh_attempts[ip] += 1

            try:
                # Hand over to the refresh endpoint, then come back here.
                return redirect(url_for('token_refresh', next=request.path))
            except Exception as e:
                logger.error("Token refresh error: %s", e)
                return force_login()

        # Restore PocketBase auth if token exists in session
        # NOTE(review): app.pb is one client shared by every request, so the
        # token saved here stays in place for later sessions until cleared -
        # confirm this is acceptable.
        pb_token = session.get('user', {}).get('token')
        if pb_token and not app.pb.auth_store.token:
            try:
                app.pb.auth_store.save(pb_token, None)
                logger.debug("Restored PocketBase authentication from session")
            except Exception as e:
                logger.warning(f"Failed to restore PocketBase auth: {e}")
                # Continue - PocketBase errors are handled in the routes

        # Claims are read with .get() so a validly signed token that lacks
        # them (e.g. a refresh token) cannot raise KeyError and cause a 500.
        user_id = payload.get('user_id')
        session_user = session.get('user')
        if not session_user or session_user.get('id') != user_id:
            try:
                # NOTE(review): the fetched record is not used; the session is
                # built from token claims whether or not this call succeeds.
                app.pb.collection('users').get_one(user_id)
            except Exception as e:
                logger.error("Error fetching user: %s", e)

            # Store minimal user data in session
            session['user'] = {
                'id': user_id,
                'email': payload.get('email', ''),
                'role': payload.get('role', 'user'),
            }

        # Always ensure session is permanent
        session.permanent = True
        return None