|
from contextlib import asynccontextmanager |
|
|
|
from fastapi import Depends, FastAPI, HTTPException, Request |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import FileResponse, RedirectResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from huggingface_hub import OAuthInfo, attach_huggingface_oauth, parse_huggingface_oauth |
|
from sqlmodel import select |
|
|
|
from . import constants |
|
from .database import get_session, init_db |
|
from .schemas import UserCount |
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run start-up work for the application.

    ``init_db()`` is called once before control is handed back to the
    framework; no clean-up is performed after the ``yield``.
    """
    # Start-up phase: prepare the database before the app begins serving.
    init_db()
    # Hand control to the framework; nothing needs to be torn down afterwards.
    yield
|
|
|
|
|
|
|
# The ASGI application; ``lifespan`` runs the start-up hook defined above.
app = FastAPI(lifespan=lifespan)

# Browser origins allowed to call the API with credentials.
# NOTE(review): port 9481 is presumably where this backend is served locally,
# and 5173 the frontend dev server - confirm against the deployment config.
_ALLOWED_ORIGINS = [
    "http://localhost:5173",
    "http://0.0.0.0:9481",
    "http://localhost:9481",
    "http://127.0.0.1:9481",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# The root route depends on whether this process serves the built frontend.
if constants.SERVE_FRONTEND:
    # Expose the built static assets under /assets.
    frontend_assets = StaticFiles(directory=constants.FRONTEND_ASSETS_PATH)
    app.mount("/assets", frontend_assets, name="assets")

    @app.get("/")
    async def serve_frontend():
        """Return the frontend's index file for the root URL."""
        return FileResponse(path=constants.FRONTEND_INDEX_PATH)

else:

    @app.get("/")
    async def redirect_to_frontend():
        """Redirect the root URL to the frontend at localhost:5173."""
        return RedirectResponse(url="http://localhost:5173/")
|
|
|
|
|
|
|
|
|
# Wire Hugging Face OAuth support into the app; the dependencies below read
# the resulting sign-in state through ``parse_huggingface_oauth``.
attach_huggingface_oauth(app)
|
|
|
|
|
async def oauth_info_optional(request: Request) -> OAuthInfo | None:
    """Dependency returning whatever ``parse_huggingface_oauth`` yields.

    The result is passed through unchanged, so it may be ``None``; use this
    for endpoints that must also work for anonymous visitors.
    """
    maybe_info = parse_huggingface_oauth(request)
    return maybe_info
|
|
|
|
|
async def oauth_info_required(request: Request) -> OAuthInfo:
    """Dependency that insists on Hugging Face OAuth information.

    Returns:
        The ``OAuthInfo`` parsed from the request.

    Raises:
        HTTPException: with status 401 when no OAuth information is found.
    """
    parsed = parse_huggingface_oauth(request)
    if parsed is not None:
        return parsed

    # No OAuth information on the request: reject it.
    raise HTTPException(
        status_code=401,
        detail="Unauthorized. Please Sign in with Hugging Face.",
    )
|
|
|
|
|
|
|
@app.get("/api/health")
async def health():
    """Report liveness with a constant ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
|
|
|
@app.get("/api/user")
async def get_user(oauth_info: OAuthInfo | None = Depends(oauth_info_optional)):
    """Describe the current visitor.

    Returns a dict with ``connected`` (whether OAuth information is present)
    and ``username`` (the preferred username, or ``None`` when anonymous).
    """
    # Anonymous visitor: no OAuth information was found.
    if oauth_info is None:
        return {"connected": False, "username": None}

    return {
        "connected": True,
        "username": oauth_info.user_info.preferred_username,
    }
|
|
|
|
|
@app.get("/api/user/count")
async def get_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Return the signed-in user's counter.

    Looks up the ``UserCount`` row whose ``name`` matches the user. When no
    row exists, a fresh ``UserCount`` with ``count=0`` is returned without
    being saved.
    """
    # NOTE(review): rows are keyed by ``user_info.name`` whereas /api/user
    # reports ``preferred_username`` - confirm that ``name`` is unique per user.
    owner = oauth_info.user_info.name
    lookup = select(UserCount).where(UserCount.name == owner)

    with get_session() as session:
        found = session.exec(lookup).first()

    if found is not None:
        return found
    # Nothing stored yet: answer with an unsaved zero counter.
    return UserCount(name=owner, count=0)
|
|
|
|
|
@app.post("/api/user/count/increment")
async def increment_user_count(
    oauth_info: OAuthInfo = Depends(oauth_info_required),
) -> UserCount:
    """Add one to the signed-in user's counter and persist it.

    The row is created with ``count=0`` first when the user has none, so a
    first call returns a count of 1. The returned object is refreshed from
    the database after the commit.
    """
    owner = oauth_info.user_info.name

    with get_session() as session:
        record = session.exec(
            select(UserCount).where(UserCount.name == owner)
        ).first()
        # First increment for this user: start from a zero counter.
        record = UserCount(name=owner, count=0) if record is None else record

        record.count = record.count + 1

        session.add(record)
        session.commit()
        # Reload the committed state so the response reflects what was stored.
        session.refresh(record)
        return record
|
|