Spaces:
Paused
Paused
from typing import Annotated | |
from sqlalchemy.orm import Session | |
from contextlib import asynccontextmanager | |
from starlette.middleware.cors import CORSMiddleware | |
from fastapi.security import OAuth2PasswordRequestForm | |
from fastapi import APIRouter, FastAPI, HTTPException, Depends | |
import torch | |
from ip_adapter import IPAdapter | |
from transformers import AutoModelForImageSegmentation | |
from diffusers import ( | |
AutoPipelineForText2Image, | |
DPMSolverMultistepScheduler, | |
StableDiffusionInstructPix2PixPipeline, | |
EulerAncestralDiscreteScheduler, | |
StableDiffusionPipeline, | |
DDIMScheduler, | |
AutoencoderKL, | |
) | |
from app.api import user | |
from app.api import prompt | |
from app.api import generate | |
from app.db import get_db, init_db | |
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE | |
from app.security import authenticate_user, create_access_token, timedelta | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the ML models at application startup and drop them at shutdown.

    Decorated with ``asynccontextmanager`` so it can be passed directly as
    ``FastAPI(lifespan=...)``: everything before ``yield`` runs on startup and
    the ``finally`` block runs on shutdown.

    Args:
        app: The FastAPI application being started (not used by this function).

    Yields:
        dict with the loaded models:
        - ``"ti_pipe"``: text-to-image pipeline (``lykon/dreamshaper-xl-v2-turbo``),
        - ``"ii_pipe"``: InstructPix2Pix pipeline (``timbrooks/instruct-pix2pix``),
        - ``"br_model"``: ``briaai/RMBG-1.4`` image-segmentation model,
        - ``"iv_model"``: ``IPAdapter`` wrapping a Stable Diffusion v1.5 pipeline.
        NOTE(review): presumably the routers in ``app.api`` read these through
        the lifespan state (``request.state``) — confirm.
    """
    base_model_path = "runwayml/stable-diffusion-v1-5"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "models/image_encoder/"
    ip_ckpt = "models/ip-adapter_sd15.bin"

    # Text-to-image: fp16 weights; the scheduler is swapped for DPM-Solver
    # multistep, built from the pipeline's original scheduler config.
    ti_pipe = AutoPipelineForText2Image.from_pretrained(
        "lykon/dreamshaper-xl-v2-turbo", torch_dtype=torch.float16, variant="fp16"
    )
    ti_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        ti_pipe.scheduler.config
    )
    ti_pipe.to(DEVICE)
    ti_pipe.enable_xformers_memory_efficient_attention()

    # Instruction-guided image editing: fp16, safety checker disabled, Euler
    # ancestral scheduler built from the pipeline's original scheduler config.
    ii_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None
    )
    ii_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        ii_pipe.scheduler.config
    )
    ii_pipe.to(DEVICE)
    ii_pipe.enable_xformers_memory_efficient_attention()

    # Segmentation model; trust_remote_code=True lets transformers execute
    # model code shipped in the hub repository.
    br_model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True
    )
    br_model.to(DEVICE)

    # DDIM scheduler and fp16 VAE used to assemble the SD v1.5 base pipeline
    # that IP-Adapter wraps.
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    # NOTE(review): `pipe` is not moved to DEVICE here; presumably IPAdapter
    # does so since it receives DEVICE — confirm.
    iv_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, DEVICE)

    try:
        yield {
            "ti_pipe": ti_pipe,
            "ii_pipe": ii_pipe,
            "br_model": br_model,
            "iv_model": iv_model,
        }
    finally:
        # Drop this function's references to the models on shutdown.
        # NOTE(review): the yielded dict may still be referenced by the
        # server's lifespan state, so this alone may not release GPU memory.
        del ti_pipe
        del ii_pipe
        del br_model
        del vae
        del pipe
        del iv_model
# Application instance; model loading and teardown are delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# Top-level router onto which the feature routers are mounted.
router = APIRouter()

# CORS policy: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm this is intended outside of development.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=origins,
)

# Runs the project's database initialisation once, when this module is imported.
init_db()
def read_root():
    """Return a fixed greeting payload ``{"Hello": "World"}``.

    NOTE(review): no route decorator is attached to this function here;
    confirm how (or whether) it is registered on `app`.
    """
    payload = dict(Hello="World")
    return payload
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[Session, Depends(get_db)],
):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: OAuth2 password-flow form; its ``username`` and ``password``
            are checked with ``authenticate_user``.
        db: Database session supplied by ``get_db``.

    Returns:
        dict with ``access_token`` (from ``create_access_token`` with subject
        ``"sub"`` set to the user's username and an ``expires_delta`` of
        ACCESS_TOKEN_EXPIRE_MINUTES minutes) and ``token_type`` ``"bearer"``.

    Raises:
        HTTPException: status 400 when ``authenticate_user`` returns a falsy value.

    NOTE(review): no route decorator is attached to this function here;
    confirm how it is registered on `app`.
    """
    # Named `account` rather than `user` so it does not shadow the
    # `app.api.user` module imported at file level.
    account = authenticate_user(db, form_data.username, form_data.password)
    if not account:
        raise HTTPException(status_code=400, detail="Incorrect username or password")

    token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": token, "token_type": "bearer"}
# Mount each feature router on the shared top-level router (order preserved:
# users, prompts, generate), then attach that router to the application.
for _sub_router, _prefix, _tag in (
    (user.router, "/users", "users"),
    (prompt.router, "/prompts", "prompts"),
    (generate.router, "/generate", "generate"),
):
    router.include_router(_sub_router, prefix=_prefix, tags=[_tag])
app.include_router(router)