Spaces:
Paused
Paused
Upload 4 files
Browse files- app/api/generate.py +236 -158
- app/api/prompt.py +45 -34
- app/api/user.py +69 -55
app/api/generate.py
CHANGED
@@ -1,158 +1,236 @@
|
|
1 |
-
from app.db import get_db
|
2 |
-
from app.config import DEVICE
|
3 |
-
from app.core import schemas, crud
|
4 |
-
from app.security import get_current_user
|
5 |
-
from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
|
6 |
-
|
7 |
-
import base64
|
8 |
-
from io import BytesIO
|
9 |
-
from sqlalchemy.orm import Session
|
10 |
-
from typing import Annotated, List
|
11 |
-
from fastapi import APIRouter, Depends, HTTPException, Request
|
12 |
-
|
13 |
-
import torch
|
14 |
-
import numpy as np
|
15 |
-
from PIL import Image
|
16 |
-
import torch.nn.functional as F
|
17 |
-
from torchvision.transforms.functional import normalize
|
18 |
-
|
19 |
-
|
20 |
-
router = APIRouter()
|
21 |
-
|
22 |
-
|
23 |
-
def decode_image(image):
|
24 |
-
return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
|
25 |
-
|
26 |
-
|
27 |
-
def encode_image(image):
|
28 |
-
bytes = BytesIO()
|
29 |
-
image.save(bytes, format="PNG")
|
30 |
-
return base64.b64encode(bytes.getvalue())
|
31 |
-
|
32 |
-
|
33 |
-
def create_prompt(
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
if
|
45 |
-
|
46 |
-
if
|
47 |
-
subject = f"{
|
48 |
-
if
|
49 |
-
subject = f"{subject}, {
|
50 |
-
if
|
51 |
-
subject = f"{subject}, {
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from app.db import get_db
|
2 |
+
from app.config import DEVICE
|
3 |
+
from app.core import schemas, crud
|
4 |
+
from app.security import get_current_user
|
5 |
+
from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
|
6 |
+
|
7 |
+
import base64
|
8 |
+
from io import BytesIO
|
9 |
+
from sqlalchemy.orm import Session
|
10 |
+
from typing import Annotated, List
|
11 |
+
from fastapi import APIRouter, Depends, HTTPException, Request
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import numpy as np
|
15 |
+
from PIL import Image
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from torchvision.transforms.functional import normalize
|
18 |
+
|
19 |
+
|
20 |
+
router = APIRouter()
|
21 |
+
|
22 |
+
|
23 |
+
def decode_image(image):
|
24 |
+
return Image.open(BytesIO(base64.b64decode(image))).convert("RGB")
|
25 |
+
|
26 |
+
|
27 |
+
def encode_image(image):
|
28 |
+
bytes = BytesIO()
|
29 |
+
image.save(bytes, format="PNG")
|
30 |
+
return base64.b64encode(bytes.getvalue())
|
31 |
+
|
32 |
+
|
33 |
+
def create_prompt(
|
34 |
+
subject,
|
35 |
+
medium,
|
36 |
+
style,
|
37 |
+
artist,
|
38 |
+
website,
|
39 |
+
resolution,
|
40 |
+
additional_details,
|
41 |
+
color,
|
42 |
+
lightning,
|
43 |
+
):
|
44 |
+
if not subject:
|
45 |
+
return None
|
46 |
+
if medium:
|
47 |
+
subject = f"{medium} of {subject}"
|
48 |
+
if style:
|
49 |
+
subject = f"{subject}, {style}"
|
50 |
+
if artist:
|
51 |
+
subject = f"{subject}, by {artist}"
|
52 |
+
if website:
|
53 |
+
subject = f"{subject}, {website}"
|
54 |
+
if resolution:
|
55 |
+
subject = f"{subject}, {resolution}"
|
56 |
+
if additional_details:
|
57 |
+
subject = f"{subject}, {additional_details}"
|
58 |
+
if color:
|
59 |
+
subject = f"{subject}, {color}"
|
60 |
+
if lightning:
|
61 |
+
subject = f"{subject}, {lightning}"
|
62 |
+
return subject
|
63 |
+
|
64 |
+
|
65 |
+
@router.post("/text-image/", response_model=str)
|
66 |
+
def text_image(
|
67 |
+
model: Request,
|
68 |
+
request: TextImage,
|
69 |
+
db: Annotated[Session, Depends(get_db)],
|
70 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
71 |
+
):
|
72 |
+
if not current_user.is_active:
|
73 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
74 |
+
|
75 |
+
generator = torch.manual_seed(request.seed)
|
76 |
+
prompt = create_prompt(
|
77 |
+
request.prompt,
|
78 |
+
medium=request.medium,
|
79 |
+
style=request.style,
|
80 |
+
artist=request.artist,
|
81 |
+
website=request.website,
|
82 |
+
resolution=request.resolution,
|
83 |
+
additional_details=request.additional_details,
|
84 |
+
color=request.color,
|
85 |
+
lightning=request.lightning,
|
86 |
+
)
|
87 |
+
|
88 |
+
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
|
89 |
+
|
90 |
+
image = model.state.ti_pipe(
|
91 |
+
prompt,
|
92 |
+
num_inference_steps=request.num_inference_steps,
|
93 |
+
guidance_scale=request.guidance_scale,
|
94 |
+
generator=generator,
|
95 |
+
negative_prompt=request.negative_prompt,
|
96 |
+
).images[0]
|
97 |
+
|
98 |
+
return encode_image(image)
|
99 |
+
|
100 |
+
|
101 |
+
@router.post("/image-image/", response_model=str)
|
102 |
+
def image_image(
|
103 |
+
model: Request,
|
104 |
+
request: ImageImage,
|
105 |
+
db: Annotated[Session, Depends(get_db)],
|
106 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
107 |
+
):
|
108 |
+
if not current_user.is_active:
|
109 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
110 |
+
|
111 |
+
generator = torch.manual_seed(request.seed)
|
112 |
+
prompt = create_prompt(
|
113 |
+
request.prompt,
|
114 |
+
medium=request.medium,
|
115 |
+
style=request.style,
|
116 |
+
artist=request.artist,
|
117 |
+
website=request.website,
|
118 |
+
resolution=request.resolution,
|
119 |
+
additional_details=request.additional_details,
|
120 |
+
color=request.color,
|
121 |
+
lightning=request.lightning,
|
122 |
+
)
|
123 |
+
image = decode_image(request.image)
|
124 |
+
|
125 |
+
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
|
126 |
+
|
127 |
+
image = model.state.ii_pipe(
|
128 |
+
prompt,
|
129 |
+
image=image,
|
130 |
+
num_inference_steps=request.num_inference_steps,
|
131 |
+
guidance_scale=request.guidance_scale,
|
132 |
+
image_guidance_scale=request.image_guidance_scale,
|
133 |
+
generator=generator,
|
134 |
+
negative_prompt=request.negative_prompt,
|
135 |
+
).images[0]
|
136 |
+
|
137 |
+
return encode_image(image)
|
138 |
+
|
139 |
+
|
140 |
+
@router.post("/background-removal/", response_model=str)
|
141 |
+
def background_removal(
|
142 |
+
model: Request,
|
143 |
+
request: BackgroundRemoval,
|
144 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
145 |
+
):
|
146 |
+
if not current_user.is_active:
|
147 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
148 |
+
|
149 |
+
image = decode_image(request.image)
|
150 |
+
|
151 |
+
def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
|
152 |
+
if len(im.shape) < 3:
|
153 |
+
im = im[:, :, np.newaxis]
|
154 |
+
# orig_im_size=im.shape[0:2]
|
155 |
+
im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
|
156 |
+
im_tensor = F.interpolate(
|
157 |
+
torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
|
158 |
+
)
|
159 |
+
image = torch.divide(im_tensor, 255.0)
|
160 |
+
image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
|
161 |
+
return image
|
162 |
+
|
163 |
+
def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
|
164 |
+
result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
|
165 |
+
ma = torch.max(result)
|
166 |
+
mi = torch.min(result)
|
167 |
+
result = (result - mi) / (ma - mi)
|
168 |
+
im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
|
169 |
+
im_array = np.squeeze(im_array)
|
170 |
+
return im_array
|
171 |
+
|
172 |
+
# prepare input
|
173 |
+
model_input_size = [1024, 1024]
|
174 |
+
orig_im = np.array(image)
|
175 |
+
orig_im_size = orig_im.shape[0:2]
|
176 |
+
image = preprocess_image(orig_im, model_input_size).to(DEVICE)
|
177 |
+
|
178 |
+
# inference
|
179 |
+
result = model.state.br_model(image)
|
180 |
+
|
181 |
+
# post process
|
182 |
+
result_image = postprocess_image(result[0][0], orig_im_size)
|
183 |
+
|
184 |
+
# save result
|
185 |
+
pil_im = Image.fromarray(result_image)
|
186 |
+
no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
|
187 |
+
orig_image = Image.fromarray(orig_im)
|
188 |
+
no_bg_image.paste(orig_image, mask=pil_im)
|
189 |
+
|
190 |
+
return encode_image(no_bg_image)
|
191 |
+
|
192 |
+
|
193 |
+
@router.post("/image-variations/", response_model=List[str])
|
194 |
+
def image_variations(
|
195 |
+
model: Request,
|
196 |
+
request: ImageVariations,
|
197 |
+
db: Annotated[Session, Depends(get_db)],
|
198 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
199 |
+
):
|
200 |
+
if not current_user.is_active:
|
201 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
202 |
+
|
203 |
+
prompt = create_prompt(
|
204 |
+
request.prompt,
|
205 |
+
medium=request.medium,
|
206 |
+
style=request.style,
|
207 |
+
artist=request.artist,
|
208 |
+
website=request.website,
|
209 |
+
resolution=request.resolution,
|
210 |
+
additional_details=request.additional_details,
|
211 |
+
color=request.color,
|
212 |
+
lightning=request.lightning,
|
213 |
+
)
|
214 |
+
image = decode_image(request.image)
|
215 |
+
image.resize((256, 256))
|
216 |
+
|
217 |
+
if prompt:
|
218 |
+
prompt = f"best quality, high quality, {prompt}"
|
219 |
+
crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
|
220 |
+
else:
|
221 |
+
request.scale = 1.0
|
222 |
+
|
223 |
+
images = model.state.iv_model.generate(
|
224 |
+
pil_image=image,
|
225 |
+
num_samples=request.num_samples,
|
226 |
+
num_inference_steps=request.num_inference_steps,
|
227 |
+
seed=request.seed,
|
228 |
+
prompt=prompt,
|
229 |
+
negative_prompt=request.negative_prompt,
|
230 |
+
scale=request.scale,
|
231 |
+
guidance_scale=request.guidance_scale,
|
232 |
+
)
|
233 |
+
|
234 |
+
images = [encode_image(image) for image in images]
|
235 |
+
|
236 |
+
return images
|
app/api/prompt.py
CHANGED
@@ -1,34 +1,45 @@
|
|
1 |
-
from typing import Annotated, List
|
2 |
-
from sqlalchemy.orm import Session
|
3 |
-
from fastapi import APIRouter, Depends, HTTPException
|
4 |
-
|
5 |
-
from app.db import get_db
|
6 |
-
from app.core import schemas, crud
|
7 |
-
from app.security import get_current_user
|
8 |
-
|
9 |
-
|
10 |
-
router = APIRouter()
|
11 |
-
|
12 |
-
|
13 |
-
@router.
|
14 |
-
def get_all_prompts(
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, List
|
2 |
+
from sqlalchemy.orm import Session
|
3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
4 |
+
|
5 |
+
from app.db import get_db
|
6 |
+
from app.core import schemas, crud
|
7 |
+
from app.security import get_current_user
|
8 |
+
|
9 |
+
|
10 |
+
router = APIRouter()
|
11 |
+
|
12 |
+
|
13 |
+
@router.get("/get-all-prompts/", response_model=List[schemas.Prompt])
|
14 |
+
def get_all_prompts(
|
15 |
+
db: Annotated[Session, Depends(get_db)],
|
16 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
17 |
+
):
|
18 |
+
if not current_user.is_superuser:
|
19 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
20 |
+
|
21 |
+
return crud.get_all_prompts(db=db)
|
22 |
+
|
23 |
+
|
24 |
+
@router.get("/get-prompt_by_user_id/{user_id}/", response_model=List[schemas.Prompt])
|
25 |
+
def get_prompt_by_user_id(
|
26 |
+
user_id: int,
|
27 |
+
db: Annotated[Session, Depends(get_db)],
|
28 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
29 |
+
):
|
30 |
+
if not current_user.is_superuser:
|
31 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
32 |
+
|
33 |
+
return crud.get_prompt_by_user_id(user_id=user_id, db=db)
|
34 |
+
|
35 |
+
|
36 |
+
# @router.post("/create-prompt/", response_model=schemas.Prompt)
|
37 |
+
# def create_prompt(
|
38 |
+
# prompt: str,
|
39 |
+
# db: Annotated[Session, Depends(get_db)],
|
40 |
+
# current_user: Annotated[schemas.User, Depends(get_current_user)],
|
41 |
+
# ):
|
42 |
+
# if not current_user.is_superuser:
|
43 |
+
# raise HTTPException(status_code=403, detail="Forbidden")
|
44 |
+
|
45 |
+
# return crud.create_prompt(prompt=prompt, db=db)
|
app/api/user.py
CHANGED
@@ -1,55 +1,69 @@
|
|
1 |
-
from typing import Annotated, List
|
2 |
-
from sqlalchemy.orm import Session
|
3 |
-
from fastapi import APIRouter, Depends, HTTPException
|
4 |
-
|
5 |
-
from app.db import get_db
|
6 |
-
from app.core import schemas, crud
|
7 |
-
from app.security import get_current_user
|
8 |
-
|
9 |
-
|
10 |
-
router = APIRouter()
|
11 |
-
|
12 |
-
|
13 |
-
@router.post("/create-user/", response_model=schemas.User)
|
14 |
-
def create_user(
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
if
|
20 |
-
raise HTTPException(
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Annotated, List
|
2 |
+
from sqlalchemy.orm import Session
|
3 |
+
from fastapi import APIRouter, Depends, HTTPException
|
4 |
+
|
5 |
+
from app.db import get_db
|
6 |
+
from app.core import schemas, crud
|
7 |
+
from app.security import get_current_user
|
8 |
+
|
9 |
+
|
10 |
+
router = APIRouter()
|
11 |
+
|
12 |
+
|
13 |
+
@router.post("/create-user/", response_model=schemas.User)
|
14 |
+
def create_user(
|
15 |
+
user: schemas.UserCreate,
|
16 |
+
db: Annotated[Session, Depends(get_db)],
|
17 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
18 |
+
):
|
19 |
+
if not current_user.is_superuser:
|
20 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
21 |
+
|
22 |
+
user_exists = crud.get_user_by_username(username=user.username, db=db)
|
23 |
+
if user_exists:
|
24 |
+
raise HTTPException(status_code=400, detail="Username already registered")
|
25 |
+
|
26 |
+
return crud.create_user(user=user, db=db)
|
27 |
+
|
28 |
+
|
29 |
+
@router.put("/update-user/", response_model=schemas.User)
|
30 |
+
def update_user(
|
31 |
+
user: schemas.UserUpdate,
|
32 |
+
db: Annotated[Session, Depends(get_db)],
|
33 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
34 |
+
):
|
35 |
+
if not current_user.is_superuser:
|
36 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
37 |
+
|
38 |
+
user_exists = crud.get_user_by_user_id(user_id=user.user_id, db=db)
|
39 |
+
if not user_exists:
|
40 |
+
raise HTTPException(status_code=404, detail="User not found")
|
41 |
+
|
42 |
+
return crud.update_user(user=user, db=db)
|
43 |
+
|
44 |
+
|
45 |
+
@router.get("/get-all-users/", response_model=List[schemas.User])
|
46 |
+
def get_all_users(
|
47 |
+
db: Annotated[Session, Depends(get_db)],
|
48 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
49 |
+
):
|
50 |
+
if not current_user.is_superuser:
|
51 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
52 |
+
|
53 |
+
return crud.get_all_users(db=db)
|
54 |
+
|
55 |
+
|
56 |
+
@router.get("/get-user_by_user_id/{user_id}/", response_model=schemas.User)
|
57 |
+
def get_user_by_user_id(
|
58 |
+
user_id: int,
|
59 |
+
db: Annotated[Session, Depends(get_db)],
|
60 |
+
current_user: Annotated[schemas.User, Depends(get_current_user)],
|
61 |
+
):
|
62 |
+
if not current_user.is_superuser:
|
63 |
+
raise HTTPException(status_code=403, detail="Forbidden")
|
64 |
+
|
65 |
+
user = crud.get_user_by_user_id(user_id=user_id, db=db)
|
66 |
+
if user is None:
|
67 |
+
raise HTTPException(status_code=404, detail="User not found")
|
68 |
+
|
69 |
+
return user
|