SuhasGholkar commited on
Commit
d946b01
·
verified ·
1 Parent(s): 55f9cb2

Create auth.py

Browse files
Files changed (1) hide show
  1. src/auth.py +393 -0
src/auth.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/auth.py
2
+ import os, re, time, secrets, uuid, sqlite3
3
+ from datetime import datetime, timedelta
4
+ import streamlit as st
5
+ import bcrypt
6
+ from email_validator import validate_email as _validate_email, EmailNotValidError
7
+
8
+ from src.utils import get_connection
9
+
10
+ # Config
11
+ LOCKOUT_MINUTES = int(os.getenv("LOCKOUT_MINUTES", "15"))
12
+ PASSWORD_EXPIRY_DAYS = int(os.getenv("PASSWORD_EXPIRY_DAYS", "90"))
13
+ PASSWORD_HISTORY_SIZE = int(os.getenv("PASSWORD_HISTORY_SIZE", "3"))
14
+ SESSION_TIMEOUT_MINUTES = int(os.getenv("SESSION_TIMEOUT_MINUTES", "30"))
15
+ PEPPER = os.getenv("PEPPER", "") # set this in HF Secrets
16
+
17
+ # Password policy
18
+ PASS_MIN_LEN = 8
19
+ _PASS_UPPER = re.compile(r"[A-Z]")
20
+ _PASS_LOWER = re.compile(r"[a-z]")
21
+ _PASS_DIGIT = re.compile(r"\d")
22
+ _PASS_SPECIAL = re.compile(r"[^A-Za-z0-9]")
23
+
24
+ def password_policy_errors(pw: str) -> list[str]:
25
+ errs = []
26
+ if len(pw) < PASS_MIN_LEN: errs.append(f"Min length {PASS_MIN_LEN}")
27
+ if not _PASS_UPPER.search(pw): errs.append("At least 1 uppercase")
28
+ if not _PASS_LOWER.search(pw): errs.append("At least 1 lowercase")
29
+ if not _PASS_DIGIT.search(pw): errs.append("At least 1 digit")
30
+ if not _PASS_SPECIAL.search(pw): errs.append("At least 1 special char")
31
+ return errs
32
+
33
+ def normalize_email(email: str) -> str:
34
+ return (email or "").strip().lower()
35
+
36
+ def is_alnum(s: str) -> bool:
37
+ return bool(re.fullmatch(r"[A-Za-z0-9]+", s or ""))
38
+
39
+ def is_valid_username(u: str) -> bool:
40
+ return bool(re.fullmatch(r"[A-Za-z0-9_]{3,32}", u or ""))
41
+
42
+ # Hashing
43
+ def hash_password(pw: str) -> str:
44
+ raw = (pw + PEPPER).encode("utf-8")
45
+ return "bcrypt$" + bcrypt.hashpw(raw, bcrypt.gensalt(rounds=12)).decode("utf-8")
46
+
47
+ def verify_password(pw: str, stored: str) -> bool:
48
+ if not stored or not stored.startswith("bcrypt$"):
49
+ return False
50
+ hashed = stored[len("bcrypt$"):].encode("utf-8")
51
+ raw = (pw + PEPPER).encode("utf-8")
52
+ try:
53
+ return bcrypt.checkpw(raw, hashed)
54
+ except Exception:
55
+ return False
56
+
57
+ # Schema
58
+ def init_auth_schema():
59
+ with get_connection() as conn:
60
+ conn.execute("""
61
+ CREATE TABLE IF NOT EXISTS users (
62
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
63
+ username TEXT NOT NULL UNIQUE,
64
+ email TEXT NOT NULL UNIQUE,
65
+ customer_id TEXT NOT NULL,
66
+ mobile TEXT,
67
+ role TEXT NOT NULL CHECK (role IN ('admin','user')),
68
+ password_hash TEXT NOT NULL,
69
+ password_algo TEXT NOT NULL DEFAULT 'bcrypt',
70
+ password_changed_at TEXT,
71
+ failed_attempts INTEGER NOT NULL DEFAULT 0,
72
+ locked_until TEXT,
73
+ is_active INTEGER NOT NULL DEFAULT 1,
74
+ last_login_at TEXT,
75
+ created_at TEXT NOT NULL,
76
+ updated_at TEXT NOT NULL
77
+ )""")
78
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
79
+
80
+ conn.execute("""
81
+ CREATE TABLE IF NOT EXISTS user_password_history (
82
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
83
+ user_id INTEGER NOT NULL,
84
+ password_hash TEXT NOT NULL,
85
+ changed_at TEXT NOT NULL,
86
+ FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE
87
+ )""")
88
+ conn.execute("CREATE INDEX IF NOT EXISTS idx_uph_user ON user_password_history(user_id)")
89
+
90
+ conn.execute("""
91
+ CREATE TABLE IF NOT EXISTS login_audit (
92
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
93
+ timestamp TEXT NOT NULL,
94
+ email TEXT NOT NULL,
95
+ user_id INTEGER,
96
+ result TEXT NOT NULL, -- success | failure | locked
97
+ reason TEXT,
98
+ client_ip TEXT,
99
+ user_agent TEXT,
100
+ session_id TEXT
101
+ )""")
102
+ conn.commit()
103
+
104
+ def bootstrap_admin_from_env():
105
+ with get_connection() as conn:
106
+ count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
107
+ if count > 0:
108
+ return
109
+ email = normalize_email(os.getenv("ADMIN_EMAIL", ""))
110
+ pw = os.getenv("ADMIN_PASSWORD", "")
111
+ name = os.getenv("ADMIN_NAME", "admin")
112
+ cust = os.getenv("ADMIN_CUSTOMER_ID", "ADMIN1")
113
+
114
+ if not email or not pw:
115
+ st.warning("Admin bootstrap: set ADMIN_EMAIL and ADMIN_PASSWORD for first-run admin.")
116
+ return
117
+ try:
118
+ _validate_email(email)
119
+ except EmailNotValidError:
120
+ st.warning("Admin bootstrap: invalid ADMIN_EMAIL.")
121
+ return
122
+ if not is_valid_username(name):
123
+ st.warning("Admin bootstrap: ADMIN_NAME must be 3–32 [A-Za-z0-9_].")
124
+ return
125
+ if not is_alnum(cust):
126
+ st.warning("Admin bootstrap: ADMIN_CUSTOMER_ID must be alphanumeric.")
127
+ return
128
+ errs = password_policy_errors(pw)
129
+ if errs:
130
+ st.warning("Admin bootstrap password policy: " + ", ".join(errs))
131
+ return
132
+
133
+ now = datetime.utcnow().isoformat(timespec="seconds")
134
+ pwh = hash_password(pw)
135
+ conn.execute("""
136
+ INSERT INTO users (username, email, customer_id, role, password_hash,
137
+ password_algo, password_changed_at, created_at, updated_at, is_active)
138
+ VALUES (?, ?, ?, 'admin', ?, 'bcrypt', ?, ?, ?, 1)
139
+ """, (name, email, cust, pwh, now, now, now))
140
+ uid = conn.execute("SELECT id FROM users WHERE email=?", (email,)).fetchone()[0]
141
+ conn.execute("INSERT INTO user_password_history (user_id, password_hash, changed_at) VALUES (?, ?, ?)",
142
+ (uid, pwh, now))
143
+ conn.commit()
144
+
145
+ # Session helpers
146
+ def set_session_user(row: sqlite3.Row):
147
+ st.session_state["_sid"] = st.session_state.get("_sid") or str(uuid.uuid4())
148
+ st.session_state["user"] = {
149
+ "id": row["id"],
150
+ "username": row["username"],
151
+ "email": row["email"],
152
+ "role": row["role"],
153
+ "customer_id": row["customer_id"],
154
+ "last_activity": time.time(),
155
+ "password_changed_at": row["password_changed_at"],
156
+ }
157
+
158
+ def session_active() -> bool:
159
+ u = st.session_state.get("user")
160
+ if not u:
161
+ return False
162
+ idle = time.time() - u.get("last_activity", 0)
163
+ if idle > SESSION_TIMEOUT_MINUTES * 60:
164
+ st.session_state.pop("user", None)
165
+ return False
166
+ u["last_activity"] = time.time()
167
+ st.session_state["user"] = u
168
+ return True
169
+
170
+ def session_countdown_widget():
171
+ u = st.session_state.get("user")
172
+ if not u:
173
+ return
174
+ idle = time.time() - u.get("last_activity", 0)
175
+ remain = max(0, SESSION_TIMEOUT_MINUTES*60 - int(idle))
176
+ mins, secs = divmod(remain, 60)
177
+ st.sidebar.caption(f"⏳ Session expires in {mins:02d}:{secs:02d}")
178
+
179
+ # CAPTCHA
180
+ def gen_captcha():
181
+ a, b = secrets.randbelow(8) + 2, secrets.randbelow(8) + 2
182
+ st.session_state["_captcha"] = (a, b, a + b)
183
+
184
+ def check_captcha(ans: str) -> bool:
185
+ a, b, s = st.session_state.get("_captcha", (0,0,0))
186
+ try:
187
+ return int(ans) == s
188
+ except Exception:
189
+ return False
190
+
191
+ # Queries
192
+ def user_by_email(email: str) -> sqlite3.Row | None:
193
+ with get_connection() as conn:
194
+ conn.row_factory = sqlite3.Row
195
+ return conn.execute("SELECT * FROM users WHERE email=?", (email,)).fetchone()
196
+
197
+ def email_exists(email: str) -> bool:
198
+ with get_connection() as conn:
199
+ return bool(conn.execute("SELECT 1 FROM users WHERE email=?", (email,)).fetchone())
200
+
201
+ def username_exists(username: str) -> bool:
202
+ with get_connection() as conn:
203
+ return bool(conn.execute("SELECT 1 FROM users WHERE username=?", (username,)).fetchone())
204
+
205
+ def within_lock(row: sqlite3.Row) -> bool:
206
+ val = row["locked_until"]
207
+ if not val: return False
208
+ try:
209
+ return datetime.utcnow() < datetime.fromisoformat(val)
210
+ except Exception:
211
+ return False
212
+
213
+ def is_password_expired(row: sqlite3.Row) -> bool:
214
+ changed = row["password_changed_at"]
215
+ if not changed: return True
216
+ try:
217
+ return datetime.utcnow() - datetime.fromisoformat(changed) > timedelta(days=PASSWORD_EXPIRY_DAYS)
218
+ except Exception:
219
+ return True
220
+
221
+ def record_login_audit(email: str, result: str, reason: str = "", user_id: int | None = None):
222
+ with get_connection() as conn:
223
+ conn.execute("""
224
+ INSERT INTO login_audit (timestamp, email, user_id, result, reason, client_ip, user_agent, session_id)
225
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
226
+ """, (
227
+ datetime.utcnow().isoformat(timespec="seconds"),
228
+ email, user_id, result, reason, None, None, st.session_state.get("_sid")
229
+ ))
230
+ conn.commit()
231
+
232
+ def cannot_reuse_password(user_id: int, plaintext_pw: str) -> bool:
233
+ """Check plaintext+PEPPER against last N bcrypt hashes."""
234
+ with get_connection() as conn:
235
+ rows = conn.execute("""
236
+ SELECT password_hash FROM user_password_history
237
+ WHERE user_id=? ORDER BY changed_at DESC LIMIT ?
238
+ """, (user_id, PASSWORD_HISTORY_SIZE)).fetchall()
239
+ raw = (plaintext_pw + PEPPER).encode("utf-8")
240
+ for (stored,) in rows:
241
+ if not stored.startswith("bcrypt$"):
242
+ continue
243
+ if bcrypt.checkpw(raw, stored[len("bcrypt$"):].encode("utf-8")):
244
+ return True
245
+ return False
246
+
247
+ # UI
248
+ def render_login_signup_gate():
249
+ st.title("🔐 Sign in to Olist")
250
+ c1, c2 = st.columns(2)
251
+
252
+ # LOGIN
253
+ with c1:
254
+ st.subheader("Login")
255
+ with st.form("login_form"):
256
+ email = normalize_email(st.text_input("Email"))
257
+ pw = st.text_input("Password", type="password")
258
+ if "_captcha" not in st.session_state: gen_captcha()
259
+ a, b, _ = st.session_state.get("_captcha", (0,0,0))
260
+ cap = st.text_input(f"CAPTCHA: What is {a} + {b}?")
261
+ ok = st.form_submit_button("Login")
262
+ if ok:
263
+ if not check_captcha(cap):
264
+ st.error("CAPTCHA incorrect.")
265
+ gen_captcha()
266
+ return False
267
+ if not email or not pw:
268
+ st.error("Email and password required.")
269
+ return False
270
+ row = user_by_email(email)
271
+ if not row:
272
+ record_login_audit(email, "failure", "no_user", None)
273
+ st.error("Invalid credentials.")
274
+ gen_captcha(); return False
275
+ if int(row["is_active"]) != 1:
276
+ record_login_audit(email, "failure", "disabled", row["id"])
277
+ st.error("Account disabled. Contact admin.")
278
+ return False
279
+ if within_lock(row):
280
+ record_login_audit(email, "locked", "locked", row["id"])
281
+ st.error(f"Account locked until {row['locked_until']}.")
282
+ return False
283
+ if not verify_password(pw, row["password_hash"]):
284
+ with get_connection() as conn:
285
+ fa = int(row["failed_attempts"]) + 1
286
+ locked_until = None
287
+ if fa >= 3:
288
+ locked_until = (datetime.utcnow() + timedelta(minutes=LOCKOUT_MINUTES)).isoformat(timespec="seconds")
289
+ fa = 0
290
+ conn.execute(
291
+ "UPDATE users SET failed_attempts=?, locked_until=?, updated_at=? WHERE id=?",
292
+ (fa, locked_until, datetime.utcnow().isoformat(timespec="seconds"), row["id"])
293
+ )
294
+ conn.commit()
295
+ record_login_audit(email, "failure", "bad_password", row["id"])
296
+ if locked_until:
297
+ st.error(f"Too many attempts. Locked until {locked_until}.")
298
+ else:
299
+ st.error(f"Invalid credentials. Attempts left: {3 - (int(row['failed_attempts']) + 1)}")
300
+ gen_captcha(); return False
301
+
302
+ # not reached
303
+ # success
304
+ with get_connection() as conn:
305
+ conn.execute("UPDATE users SET failed_attempts=0, locked_until=NULL, last_login_at=?, updated_at=? WHERE id=?",
306
+ (datetime.utcnow().isoformat(timespec="seconds"), datetime.utcnow().isoformat(timespec="seconds"), row["id"]))
307
+ conn.commit()
308
+ record_login_audit(email, "success", "", row["id"])
309
+ set_session_user(row)
310
+
311
+ if is_password_expired(row):
312
+ st.warning("Password expired. Please update.")
313
+ st.session_state["_force_change_pw"] = True
314
+ return False
315
+
316
+ return True
317
+
318
+ # SIGNUP
319
+ with c2:
320
+ st.subheader("Self-Signup")
321
+ with st.form("signup_form", clear_on_submit=True):
322
+ username = st.text_input("Username (3–32, letters/digits/_)")
323
+ email2 = normalize_email(st.text_input("Email"))
324
+ cust = st.text_input("Customer ID (alphanumeric)")
325
+ pw1 = st.text_input("Password", type="password")
326
+ pw2 = st.text_input("Confirm Password", type="password")
327
+ ok2 = st.form_submit_button("Create Account")
328
+ if ok2:
329
+ try:
330
+ _validate_email(email2)
331
+ except EmailNotValidError:
332
+ st.error("Invalid email format."); return False
333
+ if not is_valid_username(username):
334
+ st.error("Username must be 3–32 and letters/digits/_ only."); return False
335
+ if not is_alnum(cust):
336
+ st.error("Customer ID must be alphanumeric."); return False
337
+ if email_exists(email2):
338
+ st.error("Email already in use."); return False
339
+ if username_exists(username):
340
+ st.error("Username already in use."); return False
341
+ if pw1 != pw2:
342
+ st.error("Passwords do not match."); return False
343
+ errs = password_policy_errors(pw1)
344
+ if errs:
345
+ st.error("Password policy: " + ", ".join(errs)); return False
346
+
347
+ now = datetime.utcnow().isoformat(timespec="seconds")
348
+ pwh = hash_password(pw1)
349
+ with get_connection() as conn:
350
+ conn.execute("""
351
+ INSERT INTO users (username, email, customer_id, role, password_hash,
352
+ password_algo, password_changed_at, created_at, updated_at, is_active)
353
+ VALUES (?, ?, ?, 'user', ?, 'bcrypt', ?, ?, ?, 1)
354
+ """, (username, email2, cust, pwh, now, now, now))
355
+ uid = conn.execute("SELECT id FROM users WHERE email=?", (email2,)).fetchone()[0]
356
+ conn.execute("INSERT INTO user_password_history (user_id, password_hash, changed_at) VALUES (?, ?, ?)",
357
+ (uid, pwh, now))
358
+ conn.commit()
359
+ st.success("Account created. Please log in.")
360
+ return False
361
+
362
+ return False
363
+
364
+ def render_force_change_password():
365
+ st.subheader("Set a new password")
366
+ with st.form("change_pw", clear_on_submit=True):
367
+ p1 = st.text_input("New Password", type="password")
368
+ p2 = st.text_input("Confirm Password", type="password")
369
+ ok = st.form_submit_button("Update Password")
370
+ if ok:
371
+ errs = password_policy_errors(p1)
372
+ if p1 != p2:
373
+ st.error("Passwords do not match."); return False
374
+ if errs:
375
+ st.error("Password policy: " + ", ".join(errs)); return False
376
+ u = st.session_state.get("user")
377
+ if not u:
378
+ st.error("Session missing. Please log in again."); return False
379
+ if cannot_reuse_password(u["id"], p1):
380
+ st.error(f"Cannot reuse any of your last {PASSWORD_HISTORY_SIZE} passwords.")
381
+ return False
382
+ now = datetime.utcnow().isoformat(timespec="seconds")
383
+ pwh = hash_password(p1)
384
+ with get_connection() as conn:
385
+ conn.execute("UPDATE users SET password_hash=?, password_changed_at=?, updated_at=? WHERE id=?",
386
+ (pwh, now, now, u["id"]))
387
+ conn.execute("INSERT INTO user_password_history (user_id, password_hash, changed_at) VALUES (?, ?, ?)",
388
+ (u["id"], pwh, now))
389
+ conn.commit()
390
+ st.session_state.pop("_force_change_pw", None)
391
+ st.success("Password updated. Continue to the app.")
392
+ return True
393
+ return False