Spaces:
Paused
Paused
File size: 3,755 Bytes
c053e7d 23549a1 c053e7d 23549a1 c053e7d 0e1b8df c053e7d 23549a1 c053e7d 6de967f c053e7d 6de967f c053e7d 23549a1 648acd4 23549a1 648acd4 23549a1 c053e7d 648acd4 c053e7d 23549a1 648acd4 c053e7d 648acd4 c053e7d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
from typing import Annotated
from sqlalchemy.orm import Session
from contextlib import asynccontextmanager
from starlette.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter, FastAPI, HTTPException, Depends
import torch
from ip_adapter import IPAdapter
from transformers import AutoModelForImageSegmentation
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler, StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from app.api import user
from app.api import prompt
from app.api import generate
from app.db import get_db, init_db
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
from app.security import authenticate_user, create_access_token, timedelta
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all generation models once at startup and expose them as lifespan state.

    The yielded dict is handed to the framework as lifespan state. Its keys are:

    * ``ti_pipe``  - text-to-image pipeline (``lykon/dreamshaper-xl-v2-turbo``, fp16).
    * ``ii_pipe``  - InstructPix2Pix image-editing pipeline (``timbrooks/instruct-pix2pix``).
    * ``br_model`` - ``briaai/RMBG-1.4`` image-segmentation model (presumably used for
      background removal - confirm against the generate router).
    * ``iv_model`` - ``IPAdapter`` wrapping a Stable Diffusion 1.5 pipeline.

    Args:
        app: The application being started; required by the lifespan protocol, unused here.
    """
    base_model_path = "runwayml/stable-diffusion-v1-5"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "models/image_encoder/"
    ip_ckpt = "models/ip-adapter_sd15.bin"

    # Text-to-image: fp16 weights, with the scheduler swapped for DPM-Solver
    # (built from the checkpoint's own scheduler config).
    ti_pipe = AutoPipelineForText2Image.from_pretrained(
        'lykon/dreamshaper-xl-v2-turbo', torch_dtype=torch.float16, variant="fp16")
    ti_pipe.to(DEVICE)
    ti_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        ti_pipe.scheduler.config)

    # Instruction-guided image-to-image editing, safety checker disabled.
    ii_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None)
    ii_pipe.to(DEVICE)
    ii_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        ii_pipe.scheduler.config)

    # trust_remote_code is required because this checkpoint ships custom model code.
    br_model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True)
    br_model.to(DEVICE)

    # Explicit DDIM schedule for the SD 1.5 pipeline used by IP-Adapter
    # (scaled_linear betas 0.00085..0.012 over 1000 training timesteps).
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    # NOTE(review): vae and pipe are never moved to DEVICE in this function;
    # presumably IPAdapter does that with the DEVICE it is given - confirm.
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None
    )
    iv_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, DEVICE)

    state = {'ti_pipe': ti_pipe, 'ii_pipe': ii_pipe, 'br_model': br_model, 'iv_model': iv_model}
    try:
        yield state
    finally:
        # In `finally` so teardown also runs if an exception is thrown into
        # the generator, not only on a clean shutdown.
        # The yielded dict holds the models too, so deleting the locals alone
        # would not drop every reference this function owns; clear it first.
        state.clear()
        del ti_pipe, ii_pipe, br_model, vae, pipe, iv_model, noise_scheduler
        # NOTE(review): the framework may keep its own copy of the state until
        # shutdown completes; empty_cache only returns memory that is already
        # unreferenced back to the CUDA driver.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Application instance; model loading/teardown is handled by `lifespan`.
app = FastAPI(lifespan=lifespan)
# Aggregate router for the feature sub-routers; it is mounted on `app`
# after the sub-routers are attached further down this module.
router = APIRouter()

# NOTE(review): a wildcard origin together with allow_credentials=True may let
# any website make credentialed cross-origin requests to this API. Consider
# restricting `origins` to the known front-end host(s) before public deployment.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Database initialisation runs at import time, before the app serves requests.
init_db()
@app.get("/")
def read_root():
    """Return a constant JSON greeting for the root path."""
    greeting = dict(Hello="World")
    return greeting
@app.post("/token")
def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[Session, Depends(get_db)],
):
    """Exchange an OAuth2 password-flow form for a bearer access token.

    Declared with plain ``def`` rather than ``async def`` because every call in
    the body is synchronous: ``authenticate_user`` receives a synchronous
    SQLAlchemy ``Session``, and ``create_access_token`` is not awaited.
    FastAPI runs sync path operations in its threadpool, so these calls no
    longer block the event loop.

    Args:
        form_data: Password-flow form carrying ``username`` and ``password``.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``.

    Raises:
        HTTPException: 400 when ``authenticate_user`` returns a falsy value.
    """
    # Named so it does not shadow the module-level `user` router module.
    authenticated_user = authenticate_user(
        db, form_data.username, form_data.password)
    if not authenticated_user:
        raise HTTPException(
            status_code=400, detail="Incorrect username or password")
    access_token_expires = timedelta(
        minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    access_token = create_access_token(
        data={"sub": authenticated_user.username}, expires_delta=access_token_expires)
    return {"access_token": access_token, "token_type": "bearer"}
# Attach each feature router under "/<name>" with a matching OpenAPI tag,
# in a fixed order, then mount the aggregate router on the application.
for _feature_module, _feature_name in (
    (user, "users"),
    (prompt, "prompts"),
    (generate, "generate"),
):
    router.include_router(
        _feature_module.router, prefix=f"/{_feature_name}", tags=[_feature_name])
# Loop variables are module-level here; drop them so the namespace is unchanged.
del _feature_module, _feature_name
app.include_router(router)
|