Spaces:
Running
Running
| import bcrypt | |
# Compatibility shim for the bcrypt/passlib pairing: when the installed bcrypt
# module has no ``__about__`` attribute, attach a stand-in class that carries
# only ``__version__`` (copied from ``bcrypt.__version__``).
# NOTE(review): presumably passlib looks up ``bcrypt.__about__.__version__`` --
# confirm against the pinned passlib release.
if not hasattr(bcrypt, "__about__"):
    bcrypt.__about__ = type(
        "About",
        (object,),
        {"__version__": bcrypt.__version__},
    )
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request, Response, Depends, HTTPException, status | |
| from typing import Optional | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| import uvicorn | |
| import asyncio | |
| import os | |
| from sqlalchemy.ext.asyncio import AsyncSession | |
| from sqlalchemy.future import select | |
| from pydantic import BaseModel, EmailStr | |
| from utils.main import RAG | |
| from database import engine, Base, get_db | |
| from models import User | |
| from auth import get_password_hash, verify_password, create_access_token, ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, SECRET_KEY | |
| from email_utils import send_reset_email | |
| from datetime import timedelta | |
| from fastapi.security import OAuth2PasswordBearer | |
| from jose import JWTError, jwt | |
| import uuid | |
| from database import engine, Base, get_db, AsyncSessionLocal | |
| # Initialize Database and Create Default Admin | |
async def init_db():
    """Create all tables and make sure a default ``admin`` account exists.

    The schema is created from ``Base.metadata``.  If no user named
    ``admin`` is found, one is inserted with role ``"Admin"``.

    The initial admin password is read from the ``ADMIN_PASSWORD``
    environment variable.  When that variable is unset or empty the
    historical built-in password is used so existing deployments keep
    working, and a warning is printed because that value is public.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    async with AsyncSessionLocal() as db:
        result = await db.execute(select(User).where(User.username == "admin"))
        if result.scalars().first() is not None:
            # Admin already present: nothing to seed.
            return

        print("Creating default admin user...")
        admin_password = os.getenv("ADMIN_PASSWORD")
        if not admin_password:
            # Backward-compatible fallback; the literal lives in source
            # control, so it must be rotated after first login.
            admin_password = "ADMin"
            print(
                "WARNING: ADMIN_PASSWORD is not set; using the built-in "
                "default admin password. Change it immediately."
            )
        db.add(
            User(
                username="admin",
                email="admin@lawbot.ai",
                hashed_password=get_password_hash(admin_password),
                role="Admin",
            )
        )
        await db.commit()
        print("Admin user created successfully.")
# Application object; init_db is registered as a startup hook.
app = FastAPI(on_startup=[init_db])

# Module-level conversation history used by the chatbot streaming code.
chat_history = []

# Absolute locations of the static assets and HTML templates, resolved
# relative to this file.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
TEMPLATES_DIR = os.path.join(BASE_DIR, "templates")

# Serve the static tree under /static and its images subfolder under /images.
app.mount(
    "/static",
    StaticFiles(directory=STATIC_DIR),
    name="static",
)
app.mount(
    "/images",
    StaticFiles(directory=os.path.join(STATIC_DIR, "images")),
    name="images",
)

# Jinja2 environment used by every page handler.
templates = Jinja2Templates(directory=TEMPLATES_DIR)
| # Pydantic Models for Auth | |
class UserCreate(BaseModel):
    # Request body for account registration.
    username: str
    email: EmailStr
    password: str  # plain text as submitted; hashed by the register handler
    role: str
    # Must be exactly 10 digits; the length/digit check is performed by the
    # register handler, not by this model.
    mobile_number: str
class UserLogin(BaseModel):
    # Request body for login.
    username: str
    password: str
    # Optional role the client is trying to sign in as; when omitted the
    # login handler skips its role comparison.
    role: Optional[str] = None
class ForgotPassword(BaseModel):
    # Request body for starting a password reset.
    email: EmailStr
class ResetPassword(BaseModel):
    # Request body for completing a password reset.
    token: str  # value previously stored in User.reset_token
    new_password: str  # plain text; hashed by the reset handler
class Interaction(BaseModel):
    # Request body describing one question/answer exchange to persist.
    caseId: str
    query: str
    response: str
    role: str  # role context under which this exchange took place
# Bearer-token scheme whose token URL is the login endpoint path.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="api/login",
)
async def get_current_user(request: Request, db: AsyncSession = Depends(get_db)):
    """Resolve the authenticated ``User`` for *request*.

    The JWT is taken from an ``Authorization: Bearer ...`` header when one is
    present (API calls); otherwise from the ``access_token`` cookie (page
    navigation).  The token's ``sub`` claim is used as the username.

    Raises:
        HTTPException: 401 when no token is supplied, the token cannot be
            decoded, it carries no ``sub`` claim, or no such user exists.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    header_value = request.headers.get("Authorization") or ""
    if header_value.startswith("Bearer "):
        token = header_value.split(" ")[1]
    else:
        token = request.cookies.get("access_token")
    if not token:
        raise unauthorized

    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized
    username: str = claims.get("sub")
    if username is None:
        raise unauthorized

    lookup = await db.execute(select(User).where(User.username == username))
    account = lookup.scalars().first()
    if account is None:
        raise unauthorized
    return account
def role_required(required_role: str):
    """Build a dependency that admits only *required_role* users or Admins.

    The returned coroutine resolves the current user via
    ``get_current_user`` and returns it, or raises a 403 ``HTTPException``
    when the user's role is neither *required_role* nor ``"Admin"``.
    """

    async def role_checker(user: User = Depends(get_current_user)):
        if user.role in (required_role, "Admin"):
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Access denied: Requires {required_role} role (or Admin)"
        )

    return role_checker
| # Auth Routes | |
async def register(user: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a new user account from a self-service registration request.

    Args:
        user: Submitted registration data.
        db: Database session.

    Returns:
        A dict with a success ``message``.

    Raises:
        HTTPException: 403 when the requested role is Admin; 400 when the
            username or email is already taken, or when the mobile number is
            not exactly 10 ASCII digits.
    """
    # The role is client-supplied. Every role check in this file treats
    # "Admin" as a bypass, so letting callers pick it here would be a
    # privilege escalation.
    if user.role.strip().casefold() == "admin":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Admin accounts cannot be self-registered",
        )

    existing = await db.execute(
        select(User).where((User.username == user.username) | (User.email == user.email))
    )
    if existing.scalars().first():
        raise HTTPException(status_code=400, detail="Username or Email already registered")

    # Exactly 10 digits. isascii() is needed because str.isdigit() alone also
    # accepts non-ASCII digit characters.
    mobile = user.mobile_number
    if not (mobile.isascii() and mobile.isdigit()) or len(mobile) != 10:
        raise HTTPException(status_code=400, detail="Mobile number must be exactly 10 digits.")

    db.add(
        User(
            username=user.username,
            email=user.email,
            hashed_password=get_password_hash(user.password),
            role=user.role,
            mobile_number=mobile,
        )
    )
    await db.commit()
    return {"message": "User created successfully"}
async def login(response: Response, user: UserLogin, db: AsyncSession = Depends(get_db)):
    """Authenticate a user, set the ``access_token`` cookie and return the token.

    Raises:
        HTTPException: 400 for an unknown username or wrong password; 403
            when a role was supplied that differs from the account's role
            and the account is not an Admin.
    """
    print(f"Login attempt for username: {user.username}")
    lookup = await db.execute(select(User).where(User.username == user.username))
    account = lookup.scalars().first()
    if not account:
        print(f"User not found: {user.username}")
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    print(f"User found: {account.username}, checking password...")
    password_valid = verify_password(user.password, account.hashed_password)
    print(f"Password valid: {password_valid}")
    if not password_valid:
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    # A requested role is only enforced for non-Admin accounts.
    if user.role and account.role not in (user.role, "Admin"):
        print(f"Role mismatch: user has {account.role}, but tried to login as {user.role}")
        raise HTTPException(status_code=403, detail=f"Access denied: This account is registered as {account.role}, not {user.role}")

    access_token = create_access_token(
        data={"sub": account.username, "role": account.role},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )

    # samesite="none" (which requires secure=True) so the cookie is sent
    # inside Hugging Face iframes.
    cookie_settings = {
        "key": "access_token",
        "value": access_token,
        "httponly": True,
        "max_age": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        "samesite": "none",
        "secure": True,
    }
    response.set_cookie(**cookie_settings)

    print(f"Login successful for {account.username}, role: {account.role}")
    is_admin = account.role == "Admin"
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "role": account.role,
        "question_count": account.question_count,
        "is_admin": is_admin
    }
async def logout(response: Response):
    """Clear the ``access_token`` cookie and confirm the logout."""
    # Same samesite/secure attributes as used when the cookie was set.
    cookie_attrs = {"samesite": "none", "secure": True}
    response.delete_cookie(key="access_token", **cookie_attrs)
    return {"message": "Logged out successfully"}
async def forgot_password(request: ForgotPassword, db: AsyncSession = Depends(get_db)):
    """Start a password reset for the account matching the given email.

    When a matching user exists, a fresh UUID4 reset token is stored on the
    user and mailed via ``send_reset_email``.  The success message is the
    same whether or not an account exists.

    Raises:
        HTTPException: 500 when sending the email fails.  The underlying
            error is printed server-side only; it is not included in the
            response so mail-server details are not exposed to clients.
    """
    generic_reply = {"message": "If an account exists, a reset email has been sent"}

    result = await db.execute(select(User).where(User.email == request.email))
    user = result.scalars().first()
    if not user:
        return generic_reply

    token = str(uuid.uuid4())
    user.reset_token = token
    await db.commit()
    try:
        await send_reset_email(user.email, token)
    except Exception as e:
        print(f"Error sending email: {e}")
        # Keep the 500 status callers already handle, but drop the raw
        # exception text from the client-visible detail.
        raise HTTPException(status_code=500, detail="Failed to send email") from e
    return generic_reply
async def reset_password(request: ResetPassword, db: AsyncSession = Depends(get_db)):
    """Set a new password for the user holding the submitted reset token.

    The token is cleared once used.

    Raises:
        HTTPException: 400 when no user has the given reset token.
    """
    lookup = await db.execute(select(User).where(User.reset_token == request.token))
    account = lookup.scalars().first()
    if account is None:
        raise HTTPException(status_code=400, detail="Invalid or expired reset token")

    account.hashed_password = get_password_hash(request.new_password)
    account.reset_token = None  # single use
    await db.commit()
    return {"message": "Password reset successful"}
async def save_interaction(interaction: Interaction, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Persist one chat exchange for the authenticated user.

    Returns:
        A dict with ``status``, ``message`` and the new row's ``id``.
    """
    from models import ChatInteraction

    # Row is tied to the caller and tagged with the submitted role context.
    record = ChatInteraction(
        case_id=interaction.caseId,
        query=interaction.query,
        response=interaction.response,
        role=interaction.role,
        user_id=current_user.id,
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)  # reload so the generated id is available

    print(f"Saved interaction for user {current_user.username}: ID={record.id}, CaseID={interaction.caseId}")
    return {"status": "success", "message": "Interaction saved", "id": record.id}
async def get_interactions(role: Optional[str] = None, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)):
    """List the latest interaction of each of the caller's conversations.

    Args:
        role: Optional role context; when given, only interactions saved
            under that role are considered.
        db: Database session.
        current_user: Authenticated user whose history is listed.

    Returns:
        Up to 20 dicts (``id``, ``case_id``, ``query``, ``created_at``),
        newest first, one per ``case_id``.
    """
    from models import ChatInteraction
    from sqlalchemy import func

    # Ownership filter, optionally narrowed to one role context.
    where_clause = ChatInteraction.user_id == current_user.id
    if role:
        where_clause = where_clause & (ChatInteraction.role == role)

    # Latest timestamp per case_id among the caller's (filtered) rows.
    latest_per_case = (
        select(
            ChatInteraction.case_id,
            func.max(ChatInteraction.created_at).label("max_created"),
        )
        .where(where_clause)
        .group_by(ChatInteraction.case_id)
        .subquery()
    )

    # The join matches on (case_id, created_at) only, so the ownership/role
    # filter is applied to the outer query as well. Without it, a row from
    # another user (or another role) with the same pair would be returned.
    result = await db.execute(
        select(ChatInteraction)
        .join(
            latest_per_case,
            (ChatInteraction.case_id == latest_per_case.c.case_id)
            & (ChatInteraction.created_at == latest_per_case.c.max_created),
        )
        .where(where_clause)
        .order_by(ChatInteraction.created_at.desc())
        .limit(20)
    )
    interactions = result.scalars().all()
    return [{"id": i.id, "case_id": i.case_id, "query": i.query, "created_at": i.created_at} for i in interactions]
async def get_conversation_thread(case_id: str, role: Optional[str] = None, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Return every exchange of one of the caller's conversations, oldest first.

    Raises:
        HTTPException: 404 when the caller has no interactions for
            *case_id* (within *role*, if given).
    """
    from models import ChatInteraction

    # Always scoped to the caller and the case; role is an optional extra.
    conditions = [
        ChatInteraction.user_id == current_user.id,
        ChatInteraction.case_id == case_id,
    ]
    if role:
        conditions.append(ChatInteraction.role == role)

    statement = (
        select(ChatInteraction)
        .where(*conditions)
        .order_by(ChatInteraction.created_at.asc())
    )
    rows = (await db.execute(statement)).scalars().all()
    if not rows:
        raise HTTPException(status_code=404, detail="Conversation not found")

    thread = []
    for row in rows:
        thread.append({"query": row.query, "response": row.response, "created_at": row.created_at})
    return thread
async def delete_conversation(case_id: str, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user)):
    """Delete all of the caller's interactions that belong to *case_id*.

    Always reports success, whether or not any rows matched.
    """
    from models import ChatInteraction
    from sqlalchemy import delete as sqlalchemy_delete

    owned_by_caller = ChatInteraction.user_id == current_user.id
    in_this_case = ChatInteraction.case_id == case_id
    statement = sqlalchemy_delete(ChatInteraction).where(owned_by_caller & in_this_case)

    await db.execute(statement)
    await db.commit()
    return {"status": "success", "message": "Conversation deleted"}
async def get_user_status(current_user: User = Depends(get_current_user)):
    """Report the caller's identity, role and question-usage figures.

    ``limit`` is ``None`` for Admins and ``2`` for everyone else.
    """
    is_admin = current_user.role == "Admin"
    return {
        "username": current_user.username,
        "role": current_user.role,
        "question_count": current_user.question_count,
        "limit": None if is_admin else 2,
        "is_admin": is_admin
    }
async def get_all_users(db: AsyncSession = Depends(get_db), current_user: User = Depends(role_required("Admin"))):
    """Return a summary dict for every user; restricted to Admins."""
    rows = (await db.execute(select(User))).scalars().all()

    summaries = []
    for account in rows:
        summaries.append(
            {
                "id": account.id,
                "username": account.username,
                "email": account.email,
                "role": account.role,
                "mobile_number": account.mobile_number,
                "question_count": account.question_count,
                "is_blocked": False  # placeholder: always False for now
            }
        )
    return summaries
| # Frontend Routes for Auth | |
async def login_page(request: Request):
    """Render the ``login.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("login.html", context)
async def register_page(request: Request):
    """Render the ``register.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("register.html", context)
async def forgot_password_page(request: Request):
    """Render the ``forgot_password.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("forgot_password.html", context)
async def reset_password_page(request: Request):
    """Render the ``reset_password.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("reset_password.html", context)
async def admin_users_page(request: Request):
    """Render the ``admin_users.html`` template.

    Deliberately has no auth dependency: the HTML shell must load for
    anyone, and the page authenticates its own data requests via JS
    ``fetch()``.
    """
    context = {"request": request}
    return templates.TemplateResponse("admin_users.html", context)
| # Home and Role Selection | |
async def role_selection(request: Request):
    """Render ``roleselection.html`` with ``show_roles`` set to False."""
    context = {"request": request, "show_roles": False}
    return templates.TemplateResponse("roleselection.html", context)
async def roleselection_page(request: Request):
    """Render ``roleselection.html`` with ``show_roles`` set to True."""
    context = {"request": request, "show_roles": True}
    return templates.TemplateResponse("roleselection.html", context)
| # Chatbot Pages | |
async def judge_chatbot(request: Request, user: User = Depends(role_required("Judge"))):
    """Render ``Judgechatbot.html``; requires the Judge role (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("Judgechatbot.html", context)
async def judge_dashboard(request: Request, user: User = Depends(role_required("Judge"))):
    """Render ``judgedashboard.html``; requires the Judge role (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("judgedashboard.html", context)
async def view_all(request: Request, user: User = Depends(role_required("Judge"))):
    """Render ``viewall.html``; requires the Judge role (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("viewall.html", context)
async def judge_calender(request: Request, user: User = Depends(role_required("Judge"))):
    """Render ``judgecalender.html``; requires the Judge role (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("judgecalender.html", context)
async def advocate_dashboard(request: Request, user: User = Depends(role_required("Advocate/Lawyer"))):
    """Render ``advocatedashboard.html``; requires Advocate/Lawyer (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("advocatedashboard.html", context)
async def advocate_resources(request: Request, user: User = Depends(role_required("Advocate/Lawyer"))):
    """Render ``advocateresources.html``; requires Advocate/Lawyer (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("advocateresources.html", context)
| # ========== Other Role Pages ========== | |
async def woman_page(request: Request, user: User = Depends(role_required("Woman"))):
    """Render ``woman.html``; requires the Woman role (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("woman.html", context)
async def citizen_page(request: Request, user: User = Depends(role_required("Citizen"))):
    """Render ``citizen.html``; requires the Citizen role (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("citizen.html", context)
async def minor_page(request: Request, user: User = Depends(role_required("Minor"))):
    """Render ``minor.html``; requires the Minor role (or Admin)."""
    context = {"request": request}
    return templates.TemplateResponse("minor.html", context)
| # ========== Chatbot Pages ========== | |
async def student_page(request: Request, user: User = Depends(role_required("Student"))):
    """Render ``studentchatbot.html``; requires the Student role (or Admin).

    NOTE(review): a later function in this file is also named
    ``student_page`` and rebinds this module-level name -- confirm that is
    intended.
    """
    context = {"request": request}
    return templates.TemplateResponse("studentchatbot.html", context)
async def advocatechatbot_page(request: Request):
    """Render the ``advocatechatbot.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("advocatechatbot.html", context)
async def womanchatbot_page(request: Request):
    """Render the ``womanchatbot.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("womanchatbot.html", context)
async def safetytips_page(request: Request):
    """Render the ``safetytips.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("safetytips.html", context)
async def resources_page(request: Request):
    """Render the ``resources.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("resources.html", context)
async def legalrights_page(request: Request):
    """Render the ``legalrights.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("legalrights.html", context)
async def fir_page(request: Request):
    """Render the ``FIR.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("FIR.html", context)
async def student_page(request: Request):
    """Render the ``studentdashboard.html`` template.

    NOTE(review): an earlier function in this file is also named
    ``student_page`` (it renders ``studentchatbot.html``); this definition
    rebinds that module-level name -- confirm that is intended.
    """
    context = {"request": request}
    return templates.TemplateResponse("studentdashboard.html", context)
| # WebSocket for Chatbot | |
async def stream_text_conversational(
    websocket: WebSocket,
    query: str,
    role: str = "General",
    history: Optional[list] = None,
):
    """Stream the RAG answer for *query* to *websocket*, piece by piece.

    Each non-empty content delta is sent as its own text frame, followed by
    a ``"[DONE]"`` frame.  On success the exchange is appended to *history*,
    which is capped at the 10 most recent exchanges.

    Args:
        websocket: Accepted connection to write to.
        query: The user's question.
        role: Role context forwarded to ``RAG``.
        history: Conversation history list to read and extend.  Defaults to
            the module-level ``chat_history`` so existing callers behave as
            before; pass a per-connection list to keep users isolated.

    Returns:
        True when the answer was streamed completely, False when an error
        occurred and an error notice was sent instead.
    """
    chat_limit = 10
    if history is None:
        history = chat_history

    model_response = ""
    try:
        completion = RAG(query, history, role=role)
        for chunk in completion:
            # Skip chunks that carry no choices rather than failing the
            # whole answer on an IndexError.
            # NOTE(review): assumes OpenAI-style streaming chunks -- confirm.
            if not chunk.choices:
                continue
            piece = chunk.choices[0].delta.content
            if piece is not None:
                await websocket.send_text(piece)
                await asyncio.sleep(0.01)
                model_response += piece
        # Signal completion to frontend
        await websocket.send_text("[DONE]")
        history.append({"user": query, "system": model_response})
        if len(history) > chat_limit:
            history.pop(0)
        return True
    except Exception as e:
        error_str = str(e)
        print(f"Chatbot error: {error_str}")
        # Show a professional message -- never dump raw API errors to users
        if "404" in error_str or "No endpoints" in error_str:
            user_msg = "\n\n⚠️ **Service Temporarily Unavailable**\nThe AI provider is currently unavailable. Please send your message again — the system will automatically switch to an alternative model.\n\nIf the issue persists, please try again in a moment."
        else:
            user_msg = "\n\n⚠️ **An error occurred.** Please try sending your message again.\n\n_If the issue persists, please contact support._"
        await websocket.send_text(user_msg)
        await websocket.send_text("[DONE]")
        return False


async def _resolve_ws_identity(token: Optional[str]):
    """Return ``(role, user_id, question_count)`` for *token*, or None."""
    if not token:
        return None
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        return None
    username = payload.get("sub")
    if not username:
        return None

    identity = None
    # No ``break``: letting the loop finish lets get_db() run its cleanup
    # deterministically instead of at garbage collection.
    async for db in get_db():
        result = await db.execute(select(User).where(User.username == username))
        user_obj = result.scalars().first()
        if user_obj:
            # Read the attributes while the session is still open.
            identity = (user_obj.role, user_obj.id, user_obj.question_count)
    return identity


async def _increment_question_count(user_id) -> Optional[int]:
    """Add one to the user's question count; return the new value or None."""
    new_count = None
    async for db in get_db():
        result = await db.execute(select(User).where(User.id == user_id))
        user_to_update = result.scalars().first()
        if user_to_update:
            user_to_update.question_count += 1
            new_count = user_to_update.question_count
            await db.commit()
    return new_count


async def conversational_chat(websocket: WebSocket, role: Optional[str] = None):
    """Serve one chatbot WebSocket session.

    The caller is identified from the ``access_token`` cookie; unauthenticated
    connections receive a notice and are closed.  Non-Admin users always chat
    under their own account role and are limited to 2 answered questions;
    Admins may choose *role* (defaulting to ``"Admin"``) and are unlimited.

    Each connection keeps its own conversation history, so one user's
    questions are never used as context for another user and a disconnect
    does not wipe anyone else's conversation.
    """
    await websocket.accept()

    # --- STRICT ROLE ENFORCEMENT ---
    # 1. Role comes from the token/cookie (source of truth).
    identity = await _resolve_ws_identity(websocket.cookies.get("access_token"))
    if identity is None:
        await websocket.send_text("\n\n⚠️ Authentication Required. Please sign in to use the chatbot.")
        await websocket.close()
        return
    token_role, user_id, question_count = identity

    # 2. Only Admins may override the role via the query parameter.
    if token_role == "Admin":
        if not role:
            role = "Admin"
    else:
        role = token_role

    history: list = []  # private to this connection
    while True:
        try:
            query = await websocket.receive_text()

            # --- USAGE LIMIT CHECK ---
            if token_role != "Admin" and question_count >= 2:
                limit_message = (
                    "### ✨ Free usage limit reached\n\n"
                    "You've reached the free usage limit (2 questions).\n"
                    "Further access is restricted.\n\n"
                    "Please contact the administrator for extended access:\n"
                    "LinkedIn: https://www.linkedin.com/in/vishwanath77"
                )
                await websocket.send_text(limit_message)
                continue

            print(f"Query ({role}): {query}")
            answered = await stream_text_conversational(
                websocket, query, role=role, history=history
            )

            # --- INCREMENT USAGE ---
            # Only successfully answered questions count against the quota;
            # the error notice asks the user to retry.
            if answered and token_role != "Admin":
                new_count = await _increment_question_count(user_id)
                if new_count is not None:
                    question_count = new_count  # keep local copy in sync
        except WebSocketDisconnect:
            history.clear()
            break
        except Exception as e:
            print(f"WebSocket error: {e}")
            try:
                await websocket.send_text(
                    "\n\n⚠️ **Connection interrupted.** Please try sending your message again."
                )
                await websocket.send_text("[DONE]")
            except Exception:
                # The socket is no longer usable; end the session.
                break
| # Run the application | |
# Script entry point: serve the app on all interfaces, on the port given by
# the PORT environment variable (8000 when unset).
if __name__ == "__main__":
    print("Starting Law Bot Server...")
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8000)))