File size: 3,952 Bytes
c053e7d
 
 
 
 
 
 
 
23549a1
c053e7d
48c05f8
 
 
 
 
 
 
 
 
c053e7d
 
 
 
0e1b8df
c053e7d
 
 
 
 
 
23549a1
 
 
 
c053e7d
6de967f
48c05f8
 
6de967f
48c05f8
 
81feb39
 
c053e7d
 
48c05f8
 
c053e7d
48c05f8
 
81feb39
 
c053e7d
 
48c05f8
 
c053e7d
 
23549a1
 
 
 
 
 
 
 
 
 
 
648acd4
 
23549a1
 
 
48c05f8
648acd4
23549a1
c053e7d
48c05f8
 
 
 
 
 
c053e7d
 
 
 
23549a1
 
648acd4
c053e7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48c05f8
 
 
 
 
c053e7d
48c05f8
 
c053e7d
48c05f8
 
c053e7d
 
 
648acd4
 
 
c053e7d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from contextlib import asynccontextmanager
from typing import Annotated

from sqlalchemy.orm import Session
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware
from fastapi.security import OAuth2PasswordRequestForm
from fastapi import APIRouter, FastAPI, HTTPException, Depends

import torch
from ip_adapter import IPAdapter
from transformers import AutoModelForImageSegmentation
from diffusers import (
    AutoPipelineForText2Image,
    DPMSolverMultistepScheduler,
    StableDiffusionInstructPix2PixPipeline,
    EulerAncestralDiscreteScheduler,
    StableDiffusionPipeline,
    DDIMScheduler,
    AutoencoderKL,
)

from app.api import user
from app.api import prompt
from app.api import generate
from app.db import get_db, init_db
from app.config import ACCESS_TOKEN_EXPIRE_MINUTES, DEVICE
from app.security import authenticate_user, create_access_token, timedelta


def _load_text2image_pipe():
    """Load the fp16 "lykon/dreamshaper-xl-v2-turbo" text-to-image pipeline on DEVICE.

    The pipeline's default scheduler is replaced by a DPMSolverMultistepScheduler
    built from the same scheduler config, and xformers attention is enabled.
    """
    pipe = AutoPipelineForText2Image.from_pretrained(
        "lykon/dreamshaper-xl-v2-turbo", torch_dtype=torch.float16, variant="fp16"
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)
    pipe.enable_xformers_memory_efficient_attention()
    return pipe


def _load_instruct_pix2pix_pipe():
    """Load the fp16 InstructPix2Pix image-to-image pipeline on DEVICE.

    The safety checker is disabled, the scheduler is swapped for an
    EulerAncestralDiscreteScheduler, and xformers attention is enabled.
    """
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None
    )
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)
    pipe.enable_xformers_memory_efficient_attention()
    return pipe


def _load_background_removal_model():
    """Load the "briaai/RMBG-1.4" image-segmentation model on DEVICE."""
    # trust_remote_code executes model code fetched from the Hub repository;
    # NOTE(review): consider pinning a `revision=` to avoid supply-chain drift.
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True
    )
    model.to(DEVICE)
    return model


def _load_ip_adapter():
    """Build an IPAdapter around a Stable Diffusion 1.5 pipeline.

    The SD 1.5 pipeline uses a DDIM scheduler and the "stabilityai/sd-vae-ft-mse"
    VAE, both in fp16, with no feature extractor or safety checker. The pipeline
    is not moved to DEVICE here; DEVICE is handed to IPAdapter, which presumably
    handles placement (confirm against the ip_adapter package).
    """
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(
        dtype=torch.float16
    )
    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        scheduler=noise_scheduler,
        vae=vae,
        feature_extractor=None,
        safety_checker=None,
    )
    return IPAdapter(
        sd_pipe, "models/image_encoder/", "models/ip-adapter_sd15.bin", DEVICE
    )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every inference model at startup and expose them as lifespan state.

    Yields:
        A dict with the keys below. Models are loaded in this order.
        - ``ti_pipe``: text-to-image pipeline.
        - ``ii_pipe``: InstructPix2Pix image-to-image pipeline.
        - ``br_model``: RMBG-1.4 segmentation model.
        - ``iv_model``: IPAdapter over SD 1.5.

    Teardown runs in ``finally`` so it also happens when the server stops
    because of an error.
    """
    # Dict displays evaluate left to right, so the load order matches the keys.
    models = {
        "ti_pipe": _load_text2image_pipe(),
        "ii_pipe": _load_instruct_pix2pix_pipe(),
        "br_model": _load_background_removal_model(),
        "iv_model": _load_ip_adapter(),
    }
    try:
        yield models
    finally:
        # Drop our references to the models.
        # NOTE(review): the server may keep its own copies of these entries in
        # the lifespan state, so GPU memory is only fully reclaimed once those
        # references go away too (or the process exits).
        models.clear()
        if torch.cuda.is_available():
            # Return cached, now-unused CUDA blocks to the driver.
            torch.cuda.empty_cache()


# Application instance; models are loaded and released by ``lifespan``.
app = FastAPI(lifespan=lifespan)

# Parent router that the feature routers are mounted on at the end of this module.
router = APIRouter()

# NOTE(review): a wildcard origin combined with allow_credentials=True makes
# Starlette's CORSMiddleware echo the caller's Origin back. Any site can then
# issue credentialed cross-origin requests. Restrict this list to the known
# front-end host(s) before deploying.
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialise the database (app.db.init_db) at import time, before any request
# can be served.
init_db()


@app.get("/")
def read_root():
    """Return a fixed greeting payload for the root URL."""
    payload = dict(Hello="World")
    return payload


@app.post("/token")
async def login(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
    db: Annotated[Session, Depends(get_db)],
):
    """Exchange a username/password form for a bearer access token.

    Args:
        form_data: OAuth2 password-flow form carrying ``username`` and ``password``.
        db: SQLAlchemy session supplied by ``get_db``.

    Returns:
        ``{"access_token": <token>, "token_type": "bearer"}``. The token is
        created from ``{"sub": <username>}`` and expires after
        ``ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Raises:
        HTTPException: 400 when ``authenticate_user`` rejects the credentials.
        This matches the OAuth2 token-endpoint ``invalid_grant`` convention.
    """
    # authenticate_user is synchronous and works on a blocking SQLAlchemy
    # Session. Running it inline inside this coroutine would stall the event
    # loop for every other request, so run it in a worker thread instead.
    account = await run_in_threadpool(
        authenticate_user, db, form_data.username, form_data.password
    )
    if not account:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    access_token = create_access_token(
        data={"sub": account.username},
        expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
    )
    return {"access_token": access_token, "token_type": "bearer"}


# Mount each feature router under its URL prefix and OpenAPI tag, in the
# order listed, then attach the combined router to the application.
for _sub_router, _prefix, _tag in (
    (user.router, "/users", "users"),
    (prompt.router, "/prompts", "prompts"),
    (generate.router, "/generate", "generate"),
):
    router.include_router(_sub_router, prefix=_prefix, tags=[_tag])
# Do not leave the loop variables behind as module attributes.
del _sub_router, _prefix, _tag
app.include_router(router)