from contextlib import asynccontextmanager
from typing import Annotated

import torch
from diffusers import (
    AutoPipelineForText2Image,
    DPMSolverMultistepScheduler,
    StableDiffusionInstructPix2PixPipeline,
    EulerAncestralDiscreteScheduler,
    StableDiffusionPipeline,
    DDIMScheduler,
    AutoencoderKL,
)
from fastapi import APIRouter, FastAPI, HTTPException, Depends
from fastapi.security import OAuth2PasswordRequestForm
from ip_adapter import IPAdapter
from sqlalchemy.orm import Session
from starlette.middleware.cors import CORSMiddleware
from transformers import AutoModelForImageSegmentation

from app.api import user
from app.api import prompt
from app.api import generate
from app.db import get_db, init_db
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
from app.security import authenticate_user, create_access_token, timedelta

# Model identifiers passed to `from_pretrained`, plus local checkpoint paths
# (relative to the process working directory) used by the IP-Adapter.
TEXT2IMAGE_MODEL_ID = "lykon/dreamshaper-xl-v2-turbo"
INSTRUCT_PIX2PIX_MODEL_ID = "timbrooks/instruct-pix2pix"
BACKGROUND_REMOVAL_MODEL_ID = "briaai/RMBG-1.4"
IP_ADAPTER_BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
IP_ADAPTER_VAE_MODEL_ID = "stabilityai/sd-vae-ft-mse"
IP_ADAPTER_IMAGE_ENCODER_PATH = "models/image_encoder/"
IP_ADAPTER_CHECKPOINT_PATH = "models/ip-adapter_sd15.bin"


def _load_text2image_pipeline():
    """Load the fp16 text-to-image pipeline on DEVICE with a DPM-Solver scheduler."""
    pipe = AutoPipelineForText2Image.from_pretrained(
        TEXT2IMAGE_MODEL_ID, torch_dtype=torch.float16, variant="fp16"
    )
    # Swap in DPMSolverMultistepScheduler, reusing the checkpoint's scheduler config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)
    pipe.enable_xformers_memory_efficient_attention()
    return pipe


def _load_instruct_pix2pix_pipeline():
    """Load the fp16 InstructPix2Pix image-editing pipeline on DEVICE."""
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        INSTRUCT_PIX2PIX_MODEL_ID, torch_dtype=torch.float16, safety_checker=None
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipe.scheduler.config
    )
    pipe.to(DEVICE)
    pipe.enable_xformers_memory_efficient_attention()
    return pipe


def _load_background_removal_model():
    """Load the RMBG image-segmentation model on DEVICE."""
    # trust_remote_code executes model code shipped with the checkpoint.
    model = AutoModelForImageSegmentation.from_pretrained(
        BACKGROUND_REMOVAL_MODEL_ID, trust_remote_code=True
    )
    model.to(DEVICE)
    return model


def _load_ip_adapter():
    """Build an IPAdapter around an fp16 SD-1.5 pipeline with a DDIM scheduler."""
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(IP_ADAPTER_VAE_MODEL_ID).to(
        dtype=torch.float16
    )
    # NOTE(review): this pipeline is not moved to DEVICE here; presumably
    # IPAdapter does that itself — confirm against the ip_adapter package.
    pipe = StableDiffusionPipeline.from_pretrained(
        IP_ADAPTER_BASE_MODEL_ID,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    return IPAdapter(
        pipe, IP_ADAPTER_IMAGE_ENCODER_PATH, IP_ADAPTER_CHECKPOINT_PATH, DEVICE
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model once at startup and yield them as lifespan state.

    The yielded mapping has the keys ``ti_pipe``, ``ii_pipe``, ``br_model`` and
    ``iv_model``. NOTE(review): the route modules presumably read these names
    from request state — confirm against ``app.api.generate`` before renaming.
    """
    # Dict literals evaluate left to right, so models load in the order
    # text2image, pix2pix, background removal, then IP-Adapter.
    state = {
        "ti_pipe": _load_text2image_pipeline(),
        "ii_pipe": _load_instruct_pix2pix_pipeline(),
        "br_model": _load_background_removal_model(),
        "iv_model": _load_ip_adapter(),
    }
    try:
        yield state
    finally:
        # Drop our references and hand cached CUDA blocks back to the driver.
        # NOTE(review): the framework may still hold its own references to the
        # yielded objects, so memory is only fully released once those go too.
        state.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


app = FastAPI(lifespan=lifespan)
router = APIRouter()

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed cross-origin requests; restrict `origins` in production.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Runs at import time, before the lifespan (and model loading) starts.
init_db()


@app.get("/")
def read_root():
    """Return a fixed greeting; usable as a liveness check."""
    return {"Hello": "World"}


@app.post("/token")
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[Session, Depends(get_db)],
):
    """Exchange a username/password form for a bearer access token.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 400 when ``authenticate_user`` returns a falsy value.
    """
    # Not named ``user``: that would shadow the imported ``app.api.user`` module.
    authenticated = authenticate_user(db, form_data.username, form_data.password)
    if not authenticated:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    access_token = create_access_token(
        data={"sub": authenticated.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}


router.include_router(user.router, prefix="/users", tags=["users"])
router.include_router(prompt.router, prefix="/prompts", tags=["prompts"])
router.include_router(generate.router, prefix="/generate", tags=["generate"])
app.include_router(router)