|
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

import httpx
import jwt
from authlib.integrations.starlette_client import OAuth
from fastapi import FastAPI, Depends, HTTPException, Request, Form, status
from fastapi.responses import RedirectResponse, HTMLResponse, JSONResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session
from starlette.middleware.sessions import SessionMiddleware

from database import get_db, get_user_by_email
from emailx import send_verification_email, generate_verification_token
from models import User
|
|
|
|
|
# --- Configuration read from the environment ---------------------------------

# Google OAuth client credentials (None when the variables are unset).
GOOGLE_CLIENT_ID = os.getenv('GOOGLE_CLIENT_ID')
GOOGLE_CLIENT_SECRET = os.getenv('GOOGLE_CLIENT_SECRET')

# Key that signs both the JWT access tokens and the session cookie.
# Never fall back to a well-known constant: anyone who knows it could forge a
# token for any email address. Without the env var, use a random per-process
# key instead; tokens and sessions then do not survive a restart and are not
# shared between workers, so warn loudly.
SECRET_KEY = os.getenv('SecretKey')
if not SECRET_KEY:
    SECRET_KEY = secrets.token_urlsafe(32)
    logging.getLogger(__name__).warning(
        "Environment variable 'SecretKey' is not set; using a random "
        "per-process secret. Set it for persistent, multi-worker sessions."
    )

# JWT signing algorithm and default access-token lifetime.
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
|
|
|
|
# --- Application objects ------------------------------------------------------

# The FastAPI app. Session support (a cookie signed with SECRET_KEY) is needed
# by the Authlib Starlette client during the OAuth redirect round-trip.
app = FastAPI()
app.add_middleware(SessionMiddleware, secret_key=SECRET_KEY)

# Registry of OAuth providers; the Google client is registered on it below.
oauth = OAuth()

# Password hashing for local (email/password) accounts.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)

# Bearer-token scheme pointing at a "token" URL.
# NOTE(review): no route in this module appears to depend on it; authentication
# reads the "access_token" cookie instead — confirm before removing.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
|
|
class TokenData(BaseModel):
    """Pydantic model that wraps a raw token string."""

    # NOTE(review): not referenced by the handlers in this module.
    token: str


class UserCreate(BaseModel):
    """Data collected by the registration form for a new local account."""

    username: str
    # Plain string; no email-format validation happens at this layer.
    email: str
    # Plaintext as submitted; register_user hashes it before storing.
    password: str
|
|
|
|
|
# Register the Google OAuth 2.0 / OpenID Connect client.
#
# client_id / client_secret come straight from os.environ, so importing this
# module raises KeyError if either variable is missing (fail fast).
#
# server_metadata_url lets Authlib discover Google's jwks_uri and
# userinfo_endpoint. Without them, the ID token returned for the "openid"
# scope cannot be validated, because Authlib has no signing keys. Explicitly
# passed endpoints below still take precedence over the discovered ones.
oauth.register(
    name='google',
    client_id=os.environ['GOOGLE_CLIENT_ID'],
    client_secret=os.environ['GOOGLE_CLIENT_SECRET'],
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    access_token_url='https://accounts.google.com/o/oauth2/token',
    authorize_url='https://accounts.google.com/o/oauth2/auth',
    authorize_params=None,
    # Relative API calls (e.g. "userinfo") are resolved against this base.
    api_base_url='https://www.googleapis.com/oauth2/v1/',
    client_kwargs={'scope': 'openid email profile'},
)
|
|
|
|
|
# Files under ./static are served at the "/static" URL prefix
# (route name "static", usable with url_for in templates).
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Jinja2 environment rooted at ./templates, shared by every HTML handler.
templates = Jinja2Templates(directory="templates")
|
|
|
|
|
@app.post("/login")
async def login(
    request: Request,
    form_data: OAuth2PasswordRequestForm = Depends(),
    db: Session = Depends(get_db),
    recaptcha_token: str = Form(...),
):
    """Authenticate an email/password login submitted from the login form.

    The OAuth2 form's ``username`` field carries the user's email address.
    Any failure re-renders ``login.html`` with an error message. On success an
    HttpOnly ``access_token`` cookie is set and the browser is redirected to
    the protected page.
    """

    def _login_page(message: str):
        # Re-render the form with the same Google sign-in link that login_get
        # provides, so the page is complete after an error.
        return templates.TemplateResponse(
            "login.html",
            {
                "request": request,
                "error_message": message,
                "google_oauth_url": request.url_for("login_oauth"),
            },
        )

    if not await verify_recaptcha(recaptcha_token):
        return _login_page("reCAPTCHA validation failed.")

    user = authenticate_user(db, form_data.username, form_data.password)
    if not user:
        return _login_page("Invalid email or password")

    if not user.is_verified:
        return _login_page("Please verify your email before accessing this resource.")

    access_token = create_access_token(
        data={"sub": user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    # 303 turns the POST into a GET. RedirectResponse's default 307 would make
    # the browser re-POST to /protected, which only accepts GET (405 error).
    response = RedirectResponse(
        app.url_path_for("get_protected"),
        status_code=status.HTTP_303_SEE_OTHER,
    )
    response.set_cookie(
        key="access_token",
        value=f"Bearer {access_token}",
        httponly=True,
        secure=True,
        samesite='Lax',
    )
    return response
|
|
|
@app.get("/login", response_class=HTMLResponse)
async def login_get(request: Request):
    """Render the login page, including the link that starts Google sign-in."""
    context = {
        "request": request,
        "google_oauth_url": request.url_for("login_oauth"),
    }
    return templates.TemplateResponse("login.html", context)
|
|
|
@app.get("/login/oauth")
async def login_oauth(request: Request):
    """Redirect the browser to Google's authorization page.

    Google is told to send the user back to the ``auth_callback`` route.
    """
    callback_url = request.url_for('auth_callback')
    return await oauth.google.authorize_redirect(request, callback_url)
|
|
|
async def _fetch_google_user_info(token: dict) -> dict:
    """Return the Google profile (expected to include 'email') for ``token``."""
    # NOTE(review): Authlib >= 1.0 validates the ID token inside
    # authorize_access_token and stores its claims under 'userinfo' (when it
    # can load Google's server metadata). The previous parse_id_token(request,
    # token) call matched only the Authlib 0.x signature. When 'userinfo' is
    # absent, ask Google's userinfo API (relative to the registered
    # api_base_url) with the access token that was just obtained.
    user_info = token.get('userinfo')
    if user_info:
        return user_info
    resp = await oauth.google.get('userinfo', token=token)
    resp.raise_for_status()
    return resp.json()


@app.get("/auth/callback")
async def auth_callback(request: Request, db: Session = Depends(get_db)):
    """Finish Google sign-in: find or create the user, then log them in.

    A user signing in with Google for the first time is created with
    ``is_verified=True`` and no password. On success an HttpOnly
    ``access_token`` cookie is set and the browser is redirected to
    ``/protected``.

    Raises:
        HTTPException: 500 if the token exchange, profile lookup, or user
            lookup/creation fails. The cause is logged and chained.
    """
    try:
        token = await oauth.google.authorize_access_token(request)
        user_info = await _fetch_google_user_info(token)
        email = user_info['email']

        user = db.query(User).filter(User.email == email).first()
        if user is None:
            user = User(email=email, username=user_info.get('name'), is_verified=True)
            db.add(user)
            db.commit()
            db.refresh(user)
    except Exception as exc:
        # Leave the session usable for the rest of the request lifecycle.
        db.rollback()
        logging.getLogger(__name__).exception("OAuth exception: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An error occurred during OAuth authentication",
        ) from exc

    access_token = create_access_token(
        data={"sub": user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    response = RedirectResponse(url="/protected")
    response.set_cookie(
        key="access_token",
        value=f"Bearer {access_token}",
        httponly=True,
        secure=True,
        samesite='Lax',
    )
    return response
|
|
|
@app.get("/register/google")
async def register_google(request: Request):
    """Start Google sign-up; this is the same redirect as Google login.

    The account itself is created when Google returns to ``auth_callback``.
    """
    return await oauth.google.authorize_redirect(
        request,
        request.url_for('auth_callback'),
    )
|
|
|
@app.get("/registration_successful", response_class=HTMLResponse)
async def registration_successful(request: Request):
    """Render the confirmation page shown after a successful registration."""
    page_context = {"request": request}
    return templates.TemplateResponse("registration_successful.html", page_context)
|
|
|
|
|
async def verify_recaptcha(recaptcha_token: str) -> bool:
    """Check a reCAPTCHA response token with Google's siteverify endpoint.

    Args:
        recaptcha_token: Token posted by the reCAPTCHA widget.

    Returns:
        True only if Google reports ``success``. False for an empty token, a
        network or HTTP error, or an unparseable reply (fail closed).
    """
    if not recaptcha_token:
        # Google would reject an empty response anyway; skip the round-trip.
        return False

    # TODO(security): this fallback key is committed to source control. Rotate
    # it in the reCAPTCHA admin console, set RECAPTCHA_SECRET, then remove it.
    recaptcha_secret = os.getenv(
        'RECAPTCHA_SECRET',
        '6LeSJgwpAAAAAJrLrvlQYhRsOjf2wKXee_Jc4Z-k',
    )
    recaptcha_url = 'https://www.google.com/recaptcha/api/siteverify'
    recaptcha_data = {
        'secret': recaptcha_secret,
        'response': recaptcha_token,
    }

    try:
        async with httpx.AsyncClient() as client:
            recaptcha_response = await client.post(recaptcha_url, data=recaptcha_data)
        recaptcha_response.raise_for_status()
        recaptcha_result = recaptcha_response.json()
    except (httpx.HTTPError, ValueError):
        # Treat an unreachable or misbehaving verifier as a failed check
        # instead of surfacing a 500 to the user.
        return False

    if not isinstance(recaptcha_result, dict):
        return False
    return bool(recaptcha_result.get('success', False))
|
|
|
@app.get("/verify", response_class=HTMLResponse)
async def verify_email(request: Request, token: str, db: Session = Depends(get_db)):
    """Confirm a user's email address from the emailed verification link.

    Renders ``verification_failed.html`` if the token matches no user or the
    user is already verified. Otherwise it marks the user verified, clears the
    stored token, sets an ``access_token`` cookie, and redirects to
    ``/protected``.
    """
    user = get_user_by_verification_token(db, token)

    problem = None
    if not user:
        problem = "Invalid verification token"
    elif user.is_verified:
        problem = "Email already verified"

    if problem is not None:
        return templates.TemplateResponse(
            "verification_failed.html",
            {"request": request, "error_message": problem},
        )

    # Consume the token so the same link cannot be replayed.
    user.is_verified = True
    user.email_verification_token = None
    db.commit()

    session_token = create_access_token(
        data={"sub": user.email},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    redirect = RedirectResponse(url="/protected")
    redirect.set_cookie(
        key="access_token",
        value="Bearer " + session_token,
        httponly=True,
        secure=True,
        samesite='Lax',
    )
    return redirect
|
|
|
def is_username_available(db: Session, username: str) -> bool:
    """Return True when no existing user has exactly ``username``."""
    clash = db.query(User).filter(User.username == username).first()
    return clash is None
|
|
|
@app.get("/register", response_class=HTMLResponse)
async def register_get(request: Request):
    """Render the registration form along with the Google sign-up link."""
    google_oauth_url = request.url_for("login_oauth")
    return templates.TemplateResponse(
        "register.html",
        {"request": request, "google_oauth_url": google_oauth_url},
    )
|
|
|
@app.post("/register")
async def register_post(
    request: Request,
    username: str = Form(...),
    email: str = Form(...),
    password: str = Form(...),
    confirm_password: str = Form(...),
    recaptcha_token: str = Form(...),
    db: Session = Depends(get_db),
):
    """Handle the registration form submission.

    Checks run in this order: reCAPTCHA, password confirmation, username
    availability, then account creation via register_user. Any failure
    re-renders ``register.html`` with an error message. On success the
    browser is redirected (303) to ``/registration_successful``.
    """

    def _register_page(message: str):
        # Include the Google sign-up link, as register_get does, so the page
        # stays complete after an error.
        return templates.TemplateResponse(
            "register.html",
            {
                "request": request,
                "error_message": message,
                "google_oauth_url": request.url_for("login_oauth"),
            },
        )

    if not await verify_recaptcha(recaptcha_token):
        return _register_page("reCAPTCHA validation failed.")

    if password != confirm_password:
        return _register_page("Passwords do not match.")

    user_data = UserCreate(username=username, email=email, password=password)
    if not is_username_available(db, user_data.username):
        return _register_page("Username already taken")

    try:
        register_user(user_data, db)
    except HTTPException as e:
        # register_user signals expected failures (e.g. duplicate email) this way.
        return _register_page(e.detail)

    return RedirectResponse(url="/registration_successful", status_code=status.HTTP_303_SEE_OTHER)
|
|
|
@app.get("/", response_class=HTMLResponse)
async def landing(request: Request):
    """Serve the public landing page."""
    page_context = {"request": request}
    return templates.TemplateResponse("landing.html", page_context)
|
|
|
def verify_password(plain_password, hashed_password):
    """Return whether ``plain_password`` matches ``hashed_password`` per pwd_context."""
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
|
|
def get_password_hash(password):
    """Hash ``password`` with the module's CryptContext for storage."""
    digest = pwd_context.hash(password)
    return digest
|
|
|
def authenticate_user(db: Session, email: str, password: str):
    """Return the user with ``email`` if ``password`` matches, else None.

    NOTE(review): accounts created via Google sign-in have no
    ``hashed_password``. This relies on passlib rejecting a None hash —
    confirm against the installed passlib version.
    """
    user = db.query(User).filter(User.email == email).first()
    if user is None:
        # Spend comparable time to a real hash check so response timing does
        # not reveal whether the email address is registered.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user.hashed_password):
        return None
    return user
|
|
|
def create_access_token(data: dict, expires_delta: timedelta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)):
    """Encode ``data`` as a JWT signed with SECRET_KEY / ALGORITHM.

    Args:
        data: Claims to embed (callers here pass ``{"sub": email}``). The dict
            is copied, not mutated.
        expires_delta: Lifetime of the token, added to the current UTC time
            and stored as the ``exp`` claim.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    to_encode = data.copy()
    # Timezone-aware UTC: datetime.utcnow() is deprecated since Python 3.12.
    # An aware datetime maps to the same Unix timestamp for "exp".
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
|
|
def verify_token(token: str):
    """Decode an access token and return its ``sub`` claim (the user's email).

    Raises:
        HTTPException: 401 if the token is expired, malformed, or has a bad
            signature, or if it carries no ``sub`` claim.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    # ExpiredSignatureError subclasses InvalidTokenError, so it must come first.
    except jwt.ExpiredSignatureError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired") from exc
    except jwt.InvalidTokenError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") from exc

    subject = payload.get("sub")
    if not subject:
        # A validly signed token without a subject identifies nobody.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")
    return subject
|
|
|
async def get_current_user(request: Request, db: Session = Depends(get_db)):
    """FastAPI dependency: resolve the verified user from the auth cookie.

    The ``access_token`` cookie holds ``"Bearer <jwt>"``, as set by the login
    handlers in this module.

    Raises:
        HTTPException: 401 if the cookie is missing or malformed, the token is
            invalid or expired, or the user does not exist or is unverified.
            Unexpected errors (e.g. database failures) propagate unchanged
            instead of being reported as 401 with an internal error message.
    """
    cookie = request.cookies.get("access_token")
    if not cookie:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")

    # partition() never raises, unlike split(" ")[1] on a value with no space.
    scheme, _, token = cookie.partition(" ")
    if scheme != "Bearer" or not token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials")

    email = verify_token(token)  # raises a 401 HTTPException on bad tokens
    user = get_user_by_email(db, email)
    if not user or not user.is_verified:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or not verified")
    return user
|
|
|
@app.get("/protected", response_class=HTMLResponse)
async def get_protected(request: Request, current_user: User = Depends(get_current_user)):
    """Render the members-only page for the authenticated, verified user."""
    display_name = current_user.username
    return templates.TemplateResponse(
        "protected.html",
        {"request": request, "user": display_name},
    )
|
|
|
def register_user(user_data: UserCreate, db: Session):
    """Create an unverified local account and email it a verification link.

    The base URL of the link comes from the ``APP_BASE_URL`` environment
    variable (default: the hosted https site).

    Returns:
        The newly created, refreshed User.

    Raises:
        HTTPException: 400 if the email address is already registered.
        Any error from the database or from sending the email propagates
        after the transaction is rolled back, so no account exists for an
        address that never received its link.
    """
    if get_user_by_email(db, user_data.email):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email already registered")

    verification_token = generate_verification_token(user_data.email)
    new_user = User(
        email=user_data.email,
        username=user_data.username,
        hashed_password=get_password_hash(user_data.password),
        email_verification_token=verification_token,
    )
    db.add(new_user)
    try:
        # Flush first so database errors (e.g. a unique-constraint violation)
        # surface before any email is sent for an account that won't exist.
        db.flush()
        # https: the link carries a secret token and must not go over plain http.
        base_url = os.getenv("APP_BASE_URL", "https://gregniuki-loginauth.hf.space").rstrip("/")
        verification_link = f"{base_url}/verify?token={verification_token}"
        send_verification_email(user_data.email, verification_link)
        db.commit()
    except Exception:
        db.rollback()
        raise

    db.refresh(new_user)
    return new_user
|
|
|
def get_user_by_verification_token(db: Session, verification_token: str):
    """Return the first user whose stored token equals ``verification_token``, or None."""
    matches = db.query(User).filter(User.email_verification_token == verification_token)
    return matches.first()
|
|
|
def reset_password(user: User, db: Session):
    """Issue a fresh token for ``user`` and email a password-reset link.

    The token is stored in ``email_verification_token``, the same column the
    /verify handler looks tokens up by. The link's base URL comes from
    ``APP_BASE_URL`` (default: the hosted https site).

    NOTE(review): because the column is shared, a reset token is also accepted
    by the email-verification lookup. No /reset-password route is visible in
    this module — confirm it exists.
    """
    verification_token = generate_verification_token(user.email)

    # Persist before emailing, so a sent link never references a token that
    # failed to save.
    user.email_verification_token = verification_token
    db.commit()

    # https: the link carries a secret token and must not go over plain http.
    base_url = os.getenv("APP_BASE_URL", "https://gregniuki-loginauth.hf.space").rstrip("/")
    reset_link = f"{base_url}/reset-password?token={verification_token}"
    send_verification_email(user.email, reset_link)