from datetime import datetime |
from typing import Annotated |
from sqlmodel import Session, select |
from fastapi import Depends, HTTPException, status |
from fastapi.security import OAuth2PasswordBearer |
from jose import JWTError, jwt |
import core.utils as utils |
from models import User, UserCreate, Site, SiteCreate, Guest |
from models import TokenData |
from config import settings |
# Extracts the bearer token from incoming requests (OAuth2 password flow);
# "login" is advertised as the endpoint where clients obtain a token.
oauth2_scheme = OAuth2PasswordBearer(
    tokenUrl="login",
)
def authenticate_user(username: str, password: str):
    """Check a username/password pair against the stored user record.

    Args:
        username: the login name to look up.
        password: the plain-text password to verify against the stored hash.

    Returns:
        The matching User on success, or False when the user does not exist
        or the password does not verify.

    Raises:
        HTTPException: 400 "Inactive user" when the credentials are valid but
            the account is disabled.
    """
    with Session(utils.engine) as session:
        query = select(User).where(User.username == username)
        candidate = session.exec(query).first()
        # Short-circuit: the password is only verified when a user was found.
        if not candidate or not utils.verify_password(password, candidate.password):
            return False
        if candidate.disabled:
            raise HTTPException(status_code=400, detail="Inactive user")
        return candidate
def get_user(username: str):
    """Load the user with the given username using a fresh session.

    Returns:
        The matching User.

    Raises:
        HTTPException: 404 "User not found" when no user has that username.
    """
    with Session(utils.engine) as session:
        found = session.exec(select(User).where(User.username == username)).first()
        if found:
            return found
        raise HTTPException(status_code=404, detail="User not found")
def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    """Resolve a bearer token to the User it was issued for.

    The token is decoded with ``settings.SECRET_KEY`` and
    ``settings.ALGORITHM``. The username is read from its "sub" claim and
    the user is then loaded via get_user.

    Args:
        token: the raw bearer token extracted by ``oauth2_scheme``.

    Returns:
        The User named by the token.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when the token
            cannot be decoded, has no usable "sub" claim, or names a user that
            does not exist.
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        raise credentials_exception from None
    username = payload.get("sub")
    # A missing, empty or non-string subject cannot identify a user. Previously
    # only None was rejected, so other values fell through to TokenData or the
    # user lookup and produced a non-401 error.
    if not isinstance(username, str) or not username:
        raise credentials_exception
    token_data = TokenData(username=username)
    try:
        user = get_user(username=token_data.username)
    except HTTPException as err:
        # get_user reports an unknown username as 404. For a bearer token, e.g.
        # one issued before the user was deleted, that is an authentication
        # failure, so answer 401 instead.
        if err.status_code != status.HTTP_404_NOT_FOUND:
            raise
        raise credentials_exception from None
    if user is None:
        raise credentials_exception
    return user
def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
):
    """Dependency returning the authenticated user, provided it is enabled.

    Raises:
        HTTPException: 400 "Inactive user" when the account is disabled.
    """
    if not current_user.disabled:
        return current_user
    raise HTTPException(status_code=400, detail="Inactive user")
def get_current_super_user(
    current_user: Annotated[User, Depends(get_current_user)],
):
    """Dependency returning the authenticated user, provided it is an enabled superuser.

    Raises:
        HTTPException: 400 "Inactive user" when the account is disabled;
            403 when the account is not a superuser.
    """
    # get_current_user never checks `disabled`; only authenticate_user does, at
    # login. Without this check, a token obtained before the account was
    # disabled would still pass admin-only checks.
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    if not current_user.is_su:
        raise HTTPException(status_code=403, detail="Action only allowed for admin")
    return current_user
def add_user(session: Session, user: UserCreate):
    """Create and persist a new user whose username is the payload's e-mail.

    Args:
        session: open database session; the new user is committed through it.
        user: creation payload. Its ``email`` becomes the username and its
            plain-text ``password`` is hashed before storage.

    Returns:
        The persisted User, refreshed from the database.

    Raises:
        HTTPException: 400 "Email already registered" when a user with that
            username already exists.
    """
    statement = select(User).where(User.username == user.email)
    if session.exec(statement).first():
        raise HTTPException(status_code=400, detail="Email already registered")
    # Read the clock once so a brand-new row has created_at == updated_at;
    # two separate datetime.now() calls can differ by a few microseconds.
    now = datetime.now()
    extra_data = {
        "password": utils.get_password_hash(user.password),
        "username": user.email,
        "created_at": now,
        "updated_at": now,
    }
    db_user = User.model_validate(user, update=extra_data)
    session.add(db_user)
    # NOTE(review): the existence check and this insert are not atomic. Two
    # concurrent sign-ups for the same e-mail can both pass the check; confirm
    # the username column has a unique constraint to catch that case.
    session.commit()
    session.refresh(db_user)
    return db_user
def edit_user(session: Session, db_user: User, user):
    """Apply a partial update from *user* to *db_user* and commit it.

    Only fields explicitly set on *user* are applied
    (``model_dump(exclude_unset=True)``). A supplied password is hashed
    first, and ``updated_at`` is set to the current time.

    Args:
        session: open database session used for the commit.
        db_user: the stored user being modified in place.
        user: update payload, a model exposing ``model_dump``.

    Returns:
        The updated User, refreshed from the database.

    Raises:
        HTTPException: 400 when adding or committing fails, e.g. because the
            new username is already taken.
    """
    user_data = user.model_dump(exclude_unset=True)
    if "password" in user_data:
        user_data["password"] = utils.get_password_hash(user_data["password"])
    db_user.sqlmodel_update(user_data, update={"updated_at": datetime.now()})
    try:
        session.add(db_user)
        session.commit()
    except Exception as err:
        # Without a rollback the session stays in a failed transaction and
        # rejects further use. Rolling back also expires db_user, so it reloads
        # its stored values instead of keeping the rejected ones.
        session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Update failed -> Hint: check for unique username",
        ) from err
    session.refresh(db_user)
    return db_user
def camera_exists(session: Session, site):
    """Reject *site* if one of its cameras is already used by a stored site.

    All non-null ``in_camera`` and ``out_camera`` values in the Site table
    are pooled together. A camera of *site* found anywhere in that pool is a
    clash, regardless of which column it is stored in.

    Args:
        session: open database session used for the lookup.
        site: object exposing ``in_camera`` and ``out_camera`` (either may be None).

    Raises:
        HTTPException: 400 naming the clashing field of *site*. When both
            fields clash, ``out_camera`` is reported.
    """
    rows = session.exec(select(Site.in_camera, Site.out_camera)).all()
    # Use a set for O(1) membership tests below.
    used = {camera for row in rows for camera in row if camera is not None}
    # NOTE(review): the query is not filtered, so if *site* is already stored
    # its own cameras are in the pool. Confirm callers only run this for new
    # sites or changed cameras.
    exists = None
    if site.in_camera is not None and site.in_camera in used:
        exists = "in_camera"
    # Only overwrite on a hit. The previous code reset `exists` to None
    # whenever out_camera was free, letting a duplicate in_camera through.
    if site.out_camera is not None and site.out_camera in used:
        exists = "out_camera"
    if exists is not None:
        raise HTTPException(status_code=400, detail=f"Camera (Device ID) already exists in {exists}s")
def push_site(session: Session, site: SiteCreate):
    """Insert *site* into the database and return it refreshed.

    Args:
        session: open database session used for the commit.
        site: the site object to persist.

    Returns:
        The persisted site, refreshed from the database.

    Raises:
        HTTPException: 400 when adding or committing fails. The hint points at
            a duplicate site name.
    """
    # NOTE(review): session.add needs a table-mapped instance; confirm callers
    # pass a Site object even though the hint says SiteCreate.
    try:
        session.add(site)
        session.commit()
    except Exception as err:
        # A failed commit leaves the session unusable until it is rolled back.
        # The rollback also discards the rejected pending site.
        session.rollback()
        raise HTTPException(
            status_code=400,
            detail="Action failed -> Hint: Check for unique site name",
        ) from err
    session.refresh(site)
    return site
def get_current_site(session: Session, current_user: User, site_id: int):
    """Return the Site *site_id*, provided it belongs to *current_user*.

    Ownership is checked first, so an id the user does not own yields 403
    even if no such site exists.

    Raises:
        HTTPException: 403 when *site_id* is not among the user's sites;
            404 when ``session.get`` finds no Site with that id.
    """
    # Attach the user to this session so its `sites` relationship can be
    # loaded. Presumably the user came from a different, already-closed
    # session; verify against callers.
    session.add(current_user)
    owned_ids = {owned.id for owned in current_user.sites}
    if site_id in owned_ids:
        site = session.get(Site, site_id)
        if site:
            return site
        raise HTTPException(status_code=404, detail="Site not found")
    raise HTTPException(status_code=403, detail="Access only allowed for own sites")
def vector_exists(session: Session, guest):
    """Reject *guest* if a stored Guest already has an equal ``vector``.

    Raises:
        HTTPException: 400 quoting the id of the existing Guest with that vector.
    """
    duplicate = session.exec(
        select(Guest).where(Guest.vector == guest.vector)
    ).first()
    if not duplicate:
        return
    raise HTTPException(status_code=400, detail=f"Guest/Host vector already exists at id {duplicate.id}")
def get_host_of_site(session: Session, current_site: Site, host_id:int):
    """Return the Guest *host_id*, provided it is one of *current_site*'s hosts.

    Membership is checked first, so an id outside the site's hosts yields
    403 even if no such Guest exists.

    Raises:
        HTTPException: 403 when *host_id* is not a host of the site;
            404 when ``session.get`` finds no Guest with that id.
    """
    belongs = any(member.id == host_id for member in current_site.hosts)
    if not belongs:
        raise HTTPException(status_code=403, detail="Access only allowed for own hosts")
    host = session.get(Guest, host_id)
    if host:
        return host
    raise HTTPException(status_code=404, detail="Host not found")
def create_su():
    """Create the superuser named by ``settings.SU_NAME`` unless it already exists.

    The password comes from ``settings.SU_PASSWORD`` and is stored hashed.
    The account is created enabled with ``is_su=True``.
    """
    with Session(utils.engine) as session:
        lookup = select(User).where(User.username == settings.SU_NAME)
        if session.exec(lookup).first():
            return
        # NOTE(review): unlike add_user, created_at/updated_at are not set here;
        # confirm the User model gives them defaults or allows NULL.
        session.add(
            User(
                username=settings.SU_NAME,
                password=utils.get_password_hash(settings.SU_PASSWORD),
                is_su=True,
                disabled=False,
            )
        )
        session.commit()