Spaces:
Paused
Paused
from typing import Annotated | |
from sqlalchemy.orm import Session | |
from contextlib import asynccontextmanager | |
from starlette.middleware.cors import CORSMiddleware | |
from fastapi.security import OAuth2PasswordRequestForm | |
from fastapi import APIRouter, FastAPI, HTTPException, Depends | |
import torch | |
from ip_adapter import IPAdapter | |
from transformers import AutoModelForImageSegmentation | |
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler, StableDiffusionPipeline, DDIMScheduler, AutoencoderKL | |
from app.api import user | |
from app.api import prompt | |
from app.api import generate | |
from app.db import get_db, init_db | |
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE | |
from app.security import authenticate_user, create_access_token, timedelta | |
def _load_text_to_image_pipeline():
    """Load the ``lykon/dreamshaper-xl-v2-turbo`` text-to-image pipeline.

    Weights are loaded as the fp16 variant and moved to ``DEVICE``. The
    pipeline's scheduler is then replaced by a ``DPMSolverMultistepScheduler``
    built from the original scheduler's config.
    """
    pipeline = AutoPipelineForText2Image.from_pretrained(
        'lykon/dreamshaper-xl-v2-turbo', torch_dtype=torch.float16, variant="fp16")
    pipeline.to(DEVICE)
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config)
    return pipeline


def _load_pix2pix_pipeline():
    """Load the ``timbrooks/instruct-pix2pix`` image-editing pipeline.

    Loaded in fp16 with the safety checker disabled and moved to ``DEVICE``.
    The scheduler is replaced by ``EulerAncestralDiscreteScheduler`` built from
    the original scheduler's config.
    """
    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None)
    pipeline.to(DEVICE)
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config)
    return pipeline


def _load_background_removal_model():
    """Load the ``briaai/RMBG-1.4`` image-segmentation model onto ``DEVICE``."""
    # trust_remote_code=True executes model code downloaded from the Hub repo.
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True)
    model.to(DEVICE)
    return model


def _load_ip_adapter():
    """Build an ``IPAdapter`` around a Stable Diffusion v1.5 pipeline.

    The base pipeline uses a DDIM scheduler, the ``sd-vae-ft-mse`` VAE in fp16,
    and has no feature extractor or safety checker. The image encoder and
    adapter checkpoint are read from the local ``models/`` directory.
    """
    base_model_path = "runwayml/stable-diffusion-v1-5"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "models/image_encoder/"
    ip_ckpt = "models/ip-adapter_sd15.bin"
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None
    )
    # NOTE(review): pipe is not moved to DEVICE here; IPAdapter receives DEVICE
    # and presumably performs the move itself — confirm.
    return IPAdapter(pipe, image_encoder_path, ip_ckpt, DEVICE)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every inference model at startup and drop references at shutdown.

    Decorated with ``@asynccontextmanager`` because FastAPI's ``lifespan=``
    expects an async context manager factory. Starlette accepts a bare async
    generator only through a deprecated compatibility path.

    Args:
        app: The FastAPI application (unused; required by the lifespan signature).

    Yields:
        A dict with keys ``'ti_pipe'`` (text-to-image), ``'ii_pipe'``
        (InstructPix2Pix), ``'br_model'`` (RMBG-1.4 segmentation) and
        ``'iv_model'`` (IP-Adapter). Starlette merges a yielded mapping into the
        lifespan state visible to request handlers; presumably the ``app.api``
        routers read these keys — verify.
    """
    # A dict literal evaluates its values in order, so models load in the same
    # sequence as before: text-to-image, pix2pix, background removal, IP-Adapter.
    models = {
        'ti_pipe': _load_text_to_image_pipeline(),
        'ii_pipe': _load_pix2pix_pipeline(),
        'br_model': _load_background_removal_model(),
        'iv_model': _load_ip_adapter(),
    }
    try:
        yield models
    finally:
        # Release this function's references even if shutdown is triggered by
        # an exception. Starlette's copy of the state may still hold them.
        models.clear()
# CORS: accept cross-origin requests from any origin, with any method/header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True makes
# Starlette reflect the caller's Origin on credentialed requests, so any site
# can send cookie-authenticated calls. Bearer tokens sent in the Authorization
# header do not need allow_credentials — confirm whether it is actually required.
origins = ["*"]

app = FastAPI(lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Aggregate router: feature routers are attached to it, then it is mounted on app.
router = APIRouter()

# Initialize the database via app.db.init_db at import time, before serving.
init_db()
@app.get("/")
def read_root():
    """Return a fixed greeting payload from the root URL.

    Without a route decorator this handler was never registered with the app.
    """
    return {"Hello": "World"}
# NOTE(review): the route path follows the FastAPI OAuth2 password-flow
# convention ("/token"); it must match the tokenUrl of any OAuth2PasswordBearer
# defined in app.security — confirm.
@app.post("/token")
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()], db: Annotated[Session, Depends(get_db)]):
    """Exchange a username/password form for a bearer access token.

    Declared as a plain ``def`` (not ``async def``) because ``authenticate_user``
    is called with a synchronous SQLAlchemy ``Session``. FastAPI runs plain
    ``def`` handlers in a threadpool, so the blocking call does not stall the
    event loop.

    Args:
        form_data: OAuth2 password-flow form supplying ``username`` and ``password``.
        db: Database session provided by ``get_db``.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``; the token's
        ``sub`` claim is the authenticated user's username.

    Raises:
        HTTPException: 400 when ``authenticate_user`` rejects the credentials.
    """
    # Named ``account`` so it does not shadow the imported ``app.api.user`` module.
    account = authenticate_user(db, form_data.username, form_data.password)
    if not account:
        raise HTTPException(
            status_code=400, detail="Incorrect username or password")
    access_token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}
# Mount each feature module's router under its URL prefix (tagged for the
# OpenAPI docs), then attach the aggregate router to the application.
for _api_module, _prefix, _tag in (
    (user, "/users", "users"),
    (prompt, "/prompts", "prompts"),
    (generate, "/generate", "generate"),
):
    router.include_router(_api_module.router, prefix=_prefix, tags=[_tag])
app.include_router(router)