# cg-gen-ai — app/main.py (Hugging Face Space source by khawir; commit 81feb39, verified)
from typing import Annotated
from sqlalchemy.orm import Session
from contextlib import asynccontextmanager
from starlette.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter, FastAPI, HTTPException, Depends
import torch
from ip_adapter import IPAdapter
from transformers import AutoModelForImageSegmentation
from diffusers import (
AutoPipelineForText2Image,
DPMSolverMultistepScheduler,
StableDiffusionInstructPix2PixPipeline,
EulerAncestralDiscreteScheduler,
StableDiffusionPipeline,
DDIMScheduler,
AutoencoderKL,
)
from app.api import user
from app.api import prompt
from app.api import generate
from app.db import get_db, init_db
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
from app.security import authenticate_user, create_access_token, timedelta
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every generative model once at startup and share them via
    FastAPI lifespan state; drop the references again on shutdown.

    The yielded mapping is attached by FastAPI to the request state, so
    route handlers can reach the pipelines without re-loading them.
    """
    # Text-to-image: DreamShaper XL turbo driven by a multistep DPM solver.
    text2img = AutoPipelineForText2Image.from_pretrained(
        "lykon/dreamshaper-xl-v2-turbo", torch_dtype=torch.float16, variant="fp16"
    )
    text2img.scheduler = DPMSolverMultistepScheduler.from_config(
        text2img.scheduler.config
    )
    text2img.to(DEVICE)
    text2img.enable_xformers_memory_efficient_attention()

    # Instruction-based image editing: InstructPix2Pix with an
    # Euler-ancestral scheduler; safety checker disabled, as upstream.
    img_edit = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None
    )
    img_edit.scheduler = EulerAncestralDiscreteScheduler.from_config(
        img_edit.scheduler.config
    )
    img_edit.to(DEVICE)
    img_edit.enable_xformers_memory_efficient_attention()

    # Background removal segmentation model (RMBG-1.4).
    bg_removal = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True
    )
    bg_removal.to(DEVICE)

    # Image-variation stack: SD 1.5 + ft-MSE VAE + IP-Adapter, sharing a
    # DDIM scheduler configured as in the IP-Adapter reference setup.
    ddim = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
        dtype=torch.float16
    )
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        scheduler=ddim,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    img_variation = IPAdapter(
        sd_pipe, "models/image_encoder/", "models/ip-adapter_sd15.bin", DEVICE
    )

    # Keys below are the names route handlers look up on request.state.
    yield {
        "ti_pipe": text2img,
        "ii_pipe": img_edit,
        "br_model": bg_removal,
        "iv_model": img_variation,
    }

    # Shutdown: release our references so the models can be collected.
    del text2img, img_edit, bg_removal, vae, sd_pipe, img_variation
# Application wiring: lifespan model loader, CORS middleware, DB schema.
app = FastAPI(lifespan=lifespan)
router = APIRouter()

# NOTE(review): browsers reject the combination of a wildcard origin and
# allow_credentials=True (the CORS spec forbids "*" with credentials).
# List the concrete frontend origins here if cookies/auth headers must
# cross origins — TODO confirm with the deployed frontend.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create tables at import time so the first request finds the schema.
init_db()
@app.get("/")
def read_root():
    """Root endpoint: trivial liveness probe for the service."""
    greeting = dict(Hello="World")
    return greeting
@app.post("/token")
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[Session, Depends(get_db)],
):
    """OAuth2 password-flow login: verify credentials and mint a JWT.

    Returns a bearer access token valid for ACCESS_TOKEN_EXPIRE_MINUTES.

    Raises:
        HTTPException: 401 when the username/password pair is invalid.
    """
    # Local renamed from `user` to `account` — the original shadowed the
    # `app.api.user` module imported at the top of this file.
    account = authenticate_user(db, form_data.username, form_data.password)
    if not account:
        # 401 (not 400) is the correct status for a failed bearer-token
        # authentication; the WWW-Authenticate header is mandated by the
        # OAuth2 bearer spec and FastAPI's security tutorial.
        raise HTTPException(
            status_code=401,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": account.username}, expires_delta=access_token_expires
    )
    return {"access_token": access_token, "token_type": "bearer"}
# Mount each feature router under its URL prefix, then attach the
# aggregate router to the application.
for sub_router, prefix, tag in (
    (user.router, "/users", "users"),
    (prompt.router, "/prompts", "prompts"),
    (generate.router, "/generate", "generate"),
):
    router.include_router(sub_router, prefix=prefix, tags=[tag])
app.include_router(router)