Spaces:
Paused
Paused
File size: 3,952 Bytes
c053e7d 23549a1 c053e7d 48c05f8 c053e7d 0e1b8df c053e7d 23549a1 c053e7d 6de967f 48c05f8 6de967f 48c05f8 81feb39 c053e7d 48c05f8 c053e7d 48c05f8 81feb39 c053e7d 48c05f8 c053e7d 23549a1 648acd4 23549a1 48c05f8 648acd4 23549a1 c053e7d 48c05f8 c053e7d 23549a1 648acd4 c053e7d 48c05f8 c053e7d 48c05f8 c053e7d 48c05f8 c053e7d 648acd4 c053e7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from typing import Annotated
from sqlalchemy.orm import Session
from contextlib import asynccontextmanager
from starlette.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter, FastAPI, HTTPException, Depends
import torch
from ip_adapter import IPAdapter
from transformers import AutoModelForImageSegmentation
from diffusers import (
AutoPipelineForText2Image,
DPMSolverMultistepScheduler,
StableDiffusionInstructPix2PixPipeline,
EulerAncestralDiscreteScheduler,
StableDiffusionPipeline,
DDIMScheduler,
AutoencoderKL,
)
from app.api import user
from app.api import prompt
from app.api import generate
from app.db import get_db, init_db
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
from app.security import authenticate_user, create_access_token, timedelta
def _load_text_to_image_pipeline():
    """Load the ``lykon/dreamshaper-xl-v2-turbo`` text-to-image pipeline onto ``DEVICE``."""
    pipe = AutoPipelineForText2Image.from_pretrained(
        "lykon/dreamshaper-xl-v2-turbo", torch_dtype=torch.float16, variant="fp16"
    )
    # Replace the checkpoint's default scheduler, reusing its scheduler config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)
    # NOTE(review): requires the optional xformers package to be installed.
    pipe.enable_xformers_memory_efficient_attention()
    return pipe


def _load_instruct_pix2pix_pipeline():
    """Load the ``timbrooks/instruct-pix2pix`` image-editing pipeline onto ``DEVICE``."""
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)
    pipe.enable_xformers_memory_efficient_attention()
    return pipe


def _load_background_removal_model():
    """Load the ``briaai/RMBG-1.4`` segmentation model onto ``DEVICE``."""
    # trust_remote_code lets transformers execute model code shipped in the repo.
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True
    )
    model.to(DEVICE)
    return model


def _load_ip_adapter():
    """Build an ``IPAdapter`` around a fp16 Stable Diffusion v1.5 pipeline."""
    base_model_path = "runwayml/stable-diffusion-v1-5"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "models/image_encoder/"
    ip_ckpt = "models/ip-adapter_sd15.bin"

    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    # The pipeline is never moved to DEVICE here.
    # NOTE(review): assumes IPAdapter moves it itself; confirm.
    return IPAdapter(pipe, image_encoder_path, ip_ckpt, DEVICE)


def _release_models(models):
    """Drop our references to the loaded models and return cached GPU memory."""
    import gc

    # Copies of the state kept elsewhere (e.g. by the server's lifespan scope)
    # hold their own references; this only releases the ones owned here.
    models.clear()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all inference models once at startup and expose them as lifespan state.

    Yields:
        A dict with keys ``"ti_pipe"`` (text-to-image), ``"ii_pipe"``
        (InstructPix2Pix), ``"br_model"`` (RMBG-1.4) and ``"iv_model"``
        (IP-Adapter). Models are loaded in that order.

    Cleanup runs in ``finally`` so it also happens when shutdown is triggered
    by an error.
    """
    # Dict literals evaluate left-to-right, preserving the load order.
    models = {
        "ti_pipe": _load_text_to_image_pipeline(),
        "ii_pipe": _load_instruct_pix2pix_pipeline(),
        "br_model": _load_background_removal_model(),
        "iv_model": _load_ip_adapter(),
    }
    try:
        yield models
    finally:
        _release_models(models)
# ASGI application; model loading happens in the `lifespan` context manager.
app = FastAPI(lifespan=lifespan)

# Aggregating router; the feature sub-routers are mounted onto it further down.
router = APIRouter()

# CORS: accept requests from any origin with any method and header.
# NOTE(review): a wildcard origin together with allow_credentials=True means
# Starlette reflects the caller's Origin for credentialed requests, so any
# site can make credentialed calls. Bearer tokens do not need credentials
# mode — consider explicit origins or allow_credentials=False.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=origins,
)

# Initialise the database at import time (see app.db.init_db).
init_db()
@app.get("/")
def read_root():
    """Respond to ``GET /`` with a fixed greeting payload."""
    payload = dict(Hello="World")
    return payload
@app.post("/token")
def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[Session, Depends(get_db)],
):
    """Exchange OAuth2 password-form credentials for a bearer access token.

    Declared as a plain ``def`` on purpose. ``authenticate_user`` and
    ``create_access_token`` are synchronous (they are called without
    ``await``) and operate on a synchronous SQLAlchemy ``Session``. FastAPI
    runs ``def`` endpoints in its threadpool, so this blocking work no longer
    stalls the event loop shared with the other endpoints.

    Args:
        form_data: ``username``/``password`` fields from the submitted form.
        db: Request-scoped database session from ``get_db``.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: status 400 when the credentials are rejected.
    """
    # Named `authenticated` (not `user`) to avoid shadowing the app.api.user module.
    authenticated = authenticate_user(db, form_data.username, form_data.password)
    if not authenticated:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    access_token = create_access_token(
        data={"sub": authenticated.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}
# Mount each feature router under "/<name>" with a matching OpenAPI tag,
# in the order users -> prompts -> generate.
for _api, _name in ((user, "users"), (prompt, "prompts"), (generate, "generate")):
    router.include_router(_api.router, prefix=f"/{_name}", tags=[_name])
del _api, _name  # keep the module namespace free of loop temporaries

# Attach the aggregated router to the application.
app.include_router(router)
|