File size: 3,755 Bytes
c053e7d
 
 
 
 
 
 
 
23549a1
c053e7d
23549a1
c053e7d
 
 
 
0e1b8df
c053e7d
 
 
 
 
 
23549a1
 
 
 
c053e7d
6de967f
 
c053e7d
6de967f
 
c053e7d
 
 
 
 
 
 
 
 
 
 
23549a1
 
 
 
 
 
 
 
 
 
 
648acd4
 
23549a1
 
 
 
648acd4
23549a1
c053e7d
648acd4
c053e7d
 
 
 
23549a1
 
648acd4
c053e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
648acd4
 
 
c053e7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from typing import Annotated
from sqlalchemy.orm import Session
from contextlib import asynccontextmanager
from starlette.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter, FastAPI, HTTPException, Depends

import torch
from ip_adapter import IPAdapter
from transformers import AutoModelForImageSegmentation
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionInstructPix2PixPipeline, EulerAncestralDiscreteScheduler, StableDiffusionPipeline, DDIMScheduler, AutoencoderKL

from app.api import user
from app.api import prompt
from app.api import generate
from app.db import get_db, init_db
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
from app.security import authenticate_user, create_access_token, timedelta


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load all models once at startup and release the local references at shutdown.

    Everything before the ``yield`` runs when the application starts; the
    ``finally`` clause runs when the lifespan context is left, whether it
    exits normally or because an exception was thrown in at the yield point.

    Args:
        app: The FastAPI application this lifespan is attached to. Not used
            in the body; it is part of the lifespan callable's signature.

    Yields:
        A dict with the keys ``'ti_pipe'``, ``'ii_pipe'``, ``'br_model'`` and
        ``'iv_model'`` holding the loaded pipelines/models. It is handed to
        the framework as lifespan state (presumably read by the handlers in
        ``app.api`` -- not visible from this file).
    """
    # Model locations for the IP-Adapter stack: two hub ids and two local paths.
    base_model_path = "runwayml/stable-diffusion-v1-5"
    vae_model_path = "stabilityai/sd-vae-ft-mse"
    image_encoder_path = "models/image_encoder/"
    ip_ckpt = "models/ip-adapter_sd15.bin"

    # Text-to-image pipeline (fp16 weights), moved to DEVICE, with its
    # scheduler replaced by a DPM-Solver multistep scheduler built from the
    # pipeline's own scheduler config.
    ti_pipe = AutoPipelineForText2Image.from_pretrained(
        'lykon/dreamshaper-xl-v2-turbo', torch_dtype=torch.float16, variant="fp16")
    ti_pipe.to(DEVICE)
    ti_pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        ti_pipe.scheduler.config)

    # InstructPix2Pix pipeline, loaded without a safety checker and switched
    # to an Euler-ancestral scheduler.
    ii_pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None)
    ii_pipe.to(DEVICE)
    ii_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
        ii_pipe.scheduler.config)

    # Image-segmentation model; trust_remote_code=True executes model code
    # shipped with the checkpoint repository.
    # NOTE(review): presumably used for background removal -- confirm in app.api.generate.
    br_model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True)
    br_model.to(DEVICE)

    # Stable Diffusion 1.5 pipeline with an explicit DDIM scheduler and a
    # separately loaded fp16 VAE, wrapped by IPAdapter. The pipeline is not
    # moved to DEVICE here; DEVICE is passed to IPAdapter instead.
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
    pipe = StableDiffusionPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None
    )
    iv_model = IPAdapter(pipe, image_encoder_path, ip_ckpt, DEVICE)

    try:
        yield {'ti_pipe': ti_pipe, 'ii_pipe': ii_pipe, 'br_model': br_model, 'iv_model': iv_model}
    finally:
        # try/finally guarantees the teardown also runs when the context is
        # exited through an exception instead of a normal shutdown.
        del ti_pipe
        del ii_pipe
        del br_model
        del vae
        del pipe
        del iv_model
        # Best effort: ask the CUDA caching allocator to hand back unused
        # blocks. Objects still referenced elsewhere are not affected.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


# Application instance; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(lifespan=lifespan)

# Parent router that is populated and mounted on `app` further down the file.
router = APIRouter()

# Origins accepted by the CORS middleware ("*" matches any origin).
origins = ["*"]

# NOTE(review): wildcard origins/methods/headers are combined with
# allow_credentials=True. The FastAPI CORS documentation says these must be
# listed explicitly when credentials are allowed -- confirm the intended
# origin list before tightening this.
_cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Runs at import time. Presumably prepares the database schema -- see app.db.
init_db()


@app.get("/")
def read_root():
    """Return a fixed greeting payload for GET requests on the root path."""
    greeting = {"Hello": "World"}
    return greeting


@app.post("/token")
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[Session, Depends(get_db)],
):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: Submitted OAuth2 password form; its ``username`` and
            ``password`` fields are passed to ``authenticate_user``.
        db: Database session supplied by the ``get_db`` dependency.

    Returns:
        A dict with ``access_token`` (from ``create_access_token``) and
        ``token_type`` set to ``"bearer"``.

    Raises:
        HTTPException: Status 400 when ``authenticate_user`` returns a
            falsy value.
    """
    # Named `account` so the imported `app.api.user` module is not shadowed.
    account = authenticate_user(db, form_data.username, form_data.password)
    if account:
        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        token = create_access_token(
            data={"sub": account.username},
            expires_delta=lifetime,
        )
        return {"access_token": token, "token_type": "bearer"}
    raise HTTPException(status_code=400, detail="Incorrect username or password")


# Attach each feature router to the parent router under its URL prefix and
# OpenAPI tag, in this order, then mount the parent router on the app.
for _feature_router, _prefix, _tag in (
    (user.router, "/users", "users"),
    (prompt.router, "/prompts", "prompts"),
    (generate.router, "/generate", "generate"),
):
    router.include_router(_feature_router, prefix=_prefix, tags=[_tag])

app.include_router(router)