khawir committed on
Commit
e6d2554
1 Parent(s): b65fd00

Update app/api/generate.py

Browse files
Files changed (1) hide show
  1. app/api/generate.py +238 -236
app/api/generate.py CHANGED
@@ -1,236 +1,238 @@
1
- from app.db import get_db
2
- from app.config import DEVICE
3
- from app.core import schemas, crud
4
- from app.security import get_current_user
5
- from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
6
-
7
- import base64
8
- from io import BytesIO
9
- from sqlalchemy.orm import Session
10
- from typing import Annotated, List
11
- from fastapi import APIRouter, Depends, HTTPException, Request
12
-
13
- import torch
14
- import numpy as np
15
- from PIL import Image
16
- import torch.nn.functional as F
17
- from torchvision.transforms.functional import normalize
18
-
19
-
20
- router = APIRouter()
21
-
22
-
23
def decode_image(image):
    """Decode a base64-encoded image payload into an RGB PIL image.

    Args:
        image: Base64 text (``str`` or ``bytes``) holding an encoded image file.

    Returns:
        The decoded ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or does not
            decode to an image Pillow can open.
    """
    try:
        raw = base64.b64decode(image)
        return Image.open(BytesIO(raw)).convert("RGB")
    except (ValueError, OSError) as err:
        # binascii.Error subclasses ValueError and Pillow's
        # UnidentifiedImageError subclasses OSError; report both as a client
        # error instead of letting them surface as a 500.
        raise HTTPException(status_code=400, detail="Invalid image data") from err
25
-
26
-
27
def encode_image(image):
    """Serialize an image to PNG and return the data base64-encoded.

    Args:
        image: Object exposing ``save(fp, format=...)``; the endpoints in this
            file pass PIL images.

    Returns:
        ``bytes`` holding the base64 encoding of the PNG data.
    """
    # Named ``buffer`` rather than ``bytes`` so the builtin is not shadowed.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue())
31
-
32
-
33
def create_prompt(
    subject,
    medium=None,
    style=None,
    artist=None,
    website=None,
    resolution=None,
    additional_details=None,
    color=None,
    lightning=None,
):
    """Build a comma-separated generation prompt from a subject and modifiers.

    Every modifier is optional; falsy modifiers (``None``, ``""``) are skipped.

    Args:
        subject: Main prompt text. A falsy subject yields no prompt at all.
        medium: Prepended as ``"<medium> of <subject>"``.
        style: Appended as-is.
        artist: Appended as ``"by <artist>"``.
        website: Appended as-is.
        resolution: Appended as-is.
        additional_details: Appended as-is.
        color: Appended as-is.
        lightning: Appended as-is (parameter name kept for existing callers).

    Returns:
        The assembled prompt string, or ``None`` when ``subject`` is falsy.
    """
    if not subject:
        return None

    head = f"{medium} of {subject}" if medium else subject
    # Order matters: it is the order the modifiers appear in the prompt.
    modifiers = (
        style,
        f"by {artist}" if artist else None,
        website,
        resolution,
        additional_details,
        color,
        lightning,
    )
    return ", ".join(str(part) for part in (head, *modifiers) if part)
63
-
64
-
65
@router.post("/text-image/", response_model=str)
def text_image(
    model: Request,
    request: TextImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Generate an image from a text prompt and return it base64-encoded.

    Args:
        model: Incoming FastAPI request; its ``state.ti_pipe`` attribute is
            called as the text-to-image pipeline.
        request: Request body with the prompt parts and sampling settings.
        db: Database session used to record the assembled prompt.
        current_user: Authenticated user making the call.

    Returns:
        Base64-encoded PNG of the first generated image.

    Raises:
        HTTPException: 403 if the user is inactive; 400 if the prompt is empty.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    if prompt is None:
        # create_prompt yields None for an empty subject; reject it here rather
        # than recording None and handing it to the pipeline.
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    # Use a per-request generator: torch.manual_seed() reseeds and returns the
    # process-wide default generator, which concurrent requests would share.
    generator = torch.Generator().manual_seed(request.seed)

    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)

    # Pure inference: no autograd bookkeeping needed.
    with torch.inference_mode():
        image = model.state.ti_pipe(
            prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            generator=generator,
            negative_prompt=request.negative_prompt,
        ).images[0]

    return encode_image(image)
99
-
100
-
101
@router.post("/image-image/", response_model=str)
def image_image(
    model: Request,
    request: ImageImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Edit an uploaded image according to a prompt and return it base64-encoded.

    Args:
        model: Incoming FastAPI request; its ``state.ii_pipe`` attribute is
            called as the image-to-image pipeline.
        request: Request body with the base64 source image, the prompt parts
            and the sampling settings.
        db: Database session used to record the assembled prompt.
        current_user: Authenticated user making the call.

    Returns:
        Base64-encoded PNG of the first generated image.

    Raises:
        HTTPException: 403 if the user is inactive; 400 if the prompt is empty.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    if prompt is None:
        # create_prompt yields None for an empty subject; reject it here rather
        # than recording None and handing it to the pipeline.
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    # Decode before touching the database so a bad upload records nothing.
    source_image = decode_image(request.image)

    # Use a per-request generator: torch.manual_seed() reseeds and returns the
    # process-wide default generator, which concurrent requests would share.
    generator = torch.Generator().manual_seed(request.seed)

    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)

    # Pure inference: no autograd bookkeeping needed.
    with torch.inference_mode():
        image = model.state.ii_pipe(
            prompt,
            image=source_image,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            image_guidance_scale=request.image_guidance_scale,
            generator=generator,
            negative_prompt=request.negative_prompt,
        ).images[0]

    return encode_image(image)
138
-
139
-
140
@router.post("/background-removal/", response_model=str)
def background_removal(
    model: Request,
    request: BackgroundRemoval,
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Cut the foreground out of an image and return it as a transparent PNG.

    The uploaded image is resized to 1024x1024 for the segmentation model, the
    predicted mask is resized back to the original size, and the original
    pixels are pasted onto a fully transparent canvas through that mask.

    Args:
        model: Incoming FastAPI request; its ``state.br_model`` attribute is
            called as the segmentation model.
        request: Request body carrying the base64 source image.
        current_user: Authenticated user making the call.

    Returns:
        Base64-encoded RGBA PNG with the background made transparent.

    Raises:
        HTTPException: 403 if the user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
        """Convert an HxW(xC) array into a resized, normalized batch tensor."""
        if im.ndim < 3:
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(
            torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
        )
        # Scale to [0, 1], then subtract a per-channel mean of 0.5 (std 1.0).
        scaled = torch.divide(im_tensor, 255.0)
        return normalize(scaled, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
        """Resize the model output to ``im_size`` and min-max scale it to uint8."""
        result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
        lo = torch.min(result)
        hi = torch.max(result)
        # A constant output makes hi - lo zero; clamp the denominator so the
        # mask becomes all zeros instead of NaNs being cast to uint8.
        span = torch.clamp(hi - lo, min=torch.finfo(result.dtype).eps)
        result = (result - lo) / span
        im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
        return np.squeeze(im_array)

    # prepare input
    model_input_size = [1024, 1024]
    orig_im = np.array(decode_image(request.image))
    orig_im_size = orig_im.shape[0:2]
    model_input = preprocess_image(orig_im, model_input_size).to(DEVICE)

    # inference + post process, without autograd bookkeeping
    with torch.inference_mode():
        result = model.state.br_model(model_input)
        # The first element of the first output is used as the mask.
        mask = postprocess_image(result[0][0], orig_im_size)

    # compose the result: original pixels where the mask is set, transparent elsewhere
    mask_im = Image.fromarray(mask)
    no_bg_image = Image.new("RGBA", mask_im.size, (0, 0, 0, 0))
    no_bg_image.paste(Image.fromarray(orig_im), mask=mask_im)

    return encode_image(no_bg_image)
191
-
192
-
193
@router.post("/image-variations/", response_model=List[str])
def image_variations(
    model: Request,
    request: ImageVariations,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Generate variations of an uploaded image, optionally steered by a prompt.

    Args:
        model: Incoming FastAPI request; ``state.iv_model.generate`` is called
            to produce the variations.
        request: Request body with the base64 source image, optional prompt
            parts and sampling settings.
        db: Database session used to record the prompt when one is given.
        current_user: Authenticated user making the call.

    Returns:
        List of base64-encoded PNGs, one per generated image.

    Raises:
        HTTPException: 403 if the user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    # NOTE: an earlier ``image.resize((256, 256))`` call here was dropped.
    # PIL's resize returns a new image and the result was discarded, so the
    # model has always received the image at its decoded size.
    image = decode_image(request.image)

    scale = request.scale
    if prompt:
        prompt = f"best quality, high quality, {prompt}"
        crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    else:
        # Without a prompt the scale is forced to 1.0.
        scale = 1.0

    images = model.state.iv_model.generate(
        pil_image=image,
        num_samples=request.num_samples,
        num_inference_steps=request.num_inference_steps,
        seed=request.seed,
        prompt=prompt,
        negative_prompt=request.negative_prompt,
        scale=scale,
        guidance_scale=request.guidance_scale,
    )

    return [encode_image(variation) for variation in images]
 
 
 
1
+ from app.db import get_db
2
+ from app.config import DEVICE
3
+ from app.core import schemas, crud
4
+ from app.security import get_current_user
5
+ from app.core.schemas import TextImage, ImageImage, BackgroundRemoval, ImageVariations
6
+
7
+ import base64
8
+ from io import BytesIO
9
+ from sqlalchemy.orm import Session
10
+ from typing import Annotated, List
11
+ from fastapi import APIRouter, Depends, HTTPException, Request
12
+
13
+ import torch
14
+ import numpy as np
15
+ from PIL import Image
16
+ import torch.nn.functional as F
17
+ from torchvision.transforms.functional import normalize
18
+
19
+
20
+ router = APIRouter()
21
+
22
+
23
def decode_image(image):
    """Decode a base64-encoded image payload into an RGB PIL image.

    Args:
        image: Base64 text (``str`` or ``bytes``) holding an encoded image file.

    Returns:
        The decoded ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        HTTPException: 400 when the payload is not valid base64 or does not
            decode to an image Pillow can open.
    """
    try:
        raw = base64.b64decode(image)
        return Image.open(BytesIO(raw)).convert("RGB")
    except (ValueError, OSError) as err:
        # binascii.Error subclasses ValueError and Pillow's
        # UnidentifiedImageError subclasses OSError; report both as a client
        # error instead of letting them surface as a 500.
        raise HTTPException(status_code=400, detail="Invalid image data") from err
25
+
26
+
27
def encode_image(image):
    """Serialize an image to PNG and return the data base64-encoded.

    Args:
        image: Object exposing ``save(fp, format=...)``; the endpoints in this
            file pass PIL images.

    Returns:
        ``bytes`` holding the base64 encoding of the PNG data.
    """
    # Named ``buffer`` rather than ``bytes`` so the builtin is not shadowed.
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue())
31
+
32
+
33
def create_prompt(
    subject,
    medium=None,
    style=None,
    artist=None,
    website=None,
    resolution=None,
    additional_details=None,
    color=None,
    lightning=None,
):
    """Build a comma-separated generation prompt from a subject and modifiers.

    Every modifier is optional; falsy modifiers (``None``, ``""``) are skipped.

    Args:
        subject: Main prompt text. A falsy subject yields no prompt at all.
        medium: Prepended as ``"<medium> of <subject>"``.
        style: Appended as-is.
        artist: Appended as ``"by <artist>"``.
        website: Appended as-is.
        resolution: Appended as-is.
        additional_details: Appended as-is.
        color: Appended as-is.
        lightning: Appended as-is (parameter name kept for existing callers).

    Returns:
        The assembled prompt string, or ``None`` when ``subject`` is falsy.
    """
    if not subject:
        return None

    head = f"{medium} of {subject}" if medium else subject
    # Order matters: it is the order the modifiers appear in the prompt.
    modifiers = (
        style,
        f"by {artist}" if artist else None,
        website,
        resolution,
        additional_details,
        color,
        lightning,
    )
    return ", ".join(str(part) for part in (head, *modifiers) if part)
63
+
64
+
65
@router.post("/text-image/", response_model=str)
def text_image(
    model: Request,
    request: TextImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Generate an image from a text prompt and return it base64-encoded.

    Args:
        model: Incoming FastAPI request; its ``state.ti_pipe`` attribute is
            called as the text-to-image pipeline.
        request: Request body with the prompt parts and sampling settings.
        db: Database session used to record the assembled prompt.
        current_user: Authenticated user making the call.

    Returns:
        Base64-encoded PNG of the first generated image.

    Raises:
        HTTPException: 403 if the user is inactive; 400 if the prompt is empty.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    if prompt is None:
        # create_prompt yields None for an empty subject; reject it here rather
        # than recording None and handing it to the pipeline.
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    # Use a per-request generator: torch.manual_seed() reseeds and returns the
    # process-wide default generator, which concurrent requests would share.
    generator = torch.Generator().manual_seed(request.seed)

    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)

    # Pure inference: no autograd bookkeeping needed.
    with torch.inference_mode():
        image = model.state.ti_pipe(
            prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            generator=generator,
            negative_prompt=request.negative_prompt,
        ).images[0]

    return encode_image(image)
100
+
101
+
102
@router.post("/image-image/", response_model=str)
def image_image(
    model: Request,
    request: ImageImage,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Edit an uploaded image according to a prompt and return it base64-encoded.

    Args:
        model: Incoming FastAPI request; its ``state.ii_pipe`` attribute is
            called as the image-to-image pipeline.
        request: Request body with the base64 source image, the prompt parts
            and the sampling settings.
        db: Database session used to record the assembled prompt.
        current_user: Authenticated user making the call.

    Returns:
        Base64-encoded PNG of the first generated image.

    Raises:
        HTTPException: 403 if the user is inactive; 400 if the prompt is empty.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    if prompt is None:
        # create_prompt yields None for an empty subject; reject it here rather
        # than recording None and handing it to the pipeline.
        raise HTTPException(status_code=400, detail="Prompt must not be empty")

    # Decode before touching the database so a bad upload records nothing.
    source_image = decode_image(request.image)

    # Use a per-request generator: torch.manual_seed() reseeds and returns the
    # process-wide default generator, which concurrent requests would share.
    generator = torch.Generator().manual_seed(request.seed)

    crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)

    # Pure inference: no autograd bookkeeping needed.
    with torch.inference_mode():
        image = model.state.ii_pipe(
            prompt,
            image=source_image,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            image_guidance_scale=request.image_guidance_scale,
            generator=generator,
            negative_prompt=request.negative_prompt,
        ).images[0]

    return encode_image(image)
140
+
141
+
142
@router.post("/background-removal/", response_model=str)
def background_removal(
    model: Request,
    request: BackgroundRemoval,
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Cut the foreground out of an image and return it as a transparent PNG.

    The uploaded image is resized to 1024x1024 for the segmentation model, the
    predicted mask is resized back to the original size, and the original
    pixels are pasted onto a fully transparent canvas through that mask.

    Args:
        model: Incoming FastAPI request; its ``state.br_model`` attribute is
            called as the segmentation model.
        request: Request body carrying the base64 source image.
        current_user: Authenticated user making the call.

    Returns:
        Base64-encoded RGBA PNG with the background made transparent.

    Raises:
        HTTPException: 403 if the user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
        """Convert an HxW(xC) array into a resized, normalized batch tensor."""
        if im.ndim < 3:
            im = im[:, :, np.newaxis]
        im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
        im_tensor = F.interpolate(
            torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
        )
        # Scale to [0, 1], then subtract a per-channel mean of 0.5 (std 1.0).
        scaled = torch.divide(im_tensor, 255.0)
        return normalize(scaled, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

    def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
        """Resize the model output to ``im_size`` and min-max scale it to uint8."""
        result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
        lo = torch.min(result)
        hi = torch.max(result)
        # A constant output makes hi - lo zero; clamp the denominator so the
        # mask becomes all zeros instead of NaNs being cast to uint8.
        span = torch.clamp(hi - lo, min=torch.finfo(result.dtype).eps)
        result = (result - lo) / span
        im_array = (result * 255).permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8)
        return np.squeeze(im_array)

    # prepare input
    model_input_size = [1024, 1024]
    orig_im = np.array(decode_image(request.image))
    orig_im_size = orig_im.shape[0:2]
    model_input = preprocess_image(orig_im, model_input_size).to(DEVICE)

    # inference + post process, without autograd bookkeeping
    with torch.inference_mode():
        result = model.state.br_model(model_input)
        # The first element of the first output is used as the mask.
        mask = postprocess_image(result[0][0], orig_im_size)

    # compose the result: original pixels where the mask is set, transparent elsewhere
    mask_im = Image.fromarray(mask)
    no_bg_image = Image.new("RGBA", mask_im.size, (0, 0, 0, 0))
    no_bg_image.paste(Image.fromarray(orig_im), mask=mask_im)

    return encode_image(no_bg_image)
193
+
194
+
195
@router.post("/image-variations/", response_model=List[str])
def image_variations(
    model: Request,
    request: ImageVariations,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[schemas.User, Depends(get_current_user)],
):
    """Generate variations of an uploaded image, optionally steered by a prompt.

    Args:
        model: Incoming FastAPI request; ``state.iv_model.generate`` is called
            to produce the variations.
        request: Request body with the base64 source image, optional prompt
            parts and sampling settings.
        db: Database session used to record the prompt when one is given.
        current_user: Authenticated user making the call.

    Returns:
        List of base64-encoded PNGs, one per generated image.

    Raises:
        HTTPException: 403 if the user is inactive.
    """
    if not current_user.is_active:
        raise HTTPException(status_code=403, detail="Forbidden")

    prompt = create_prompt(
        request.prompt,
        medium=request.medium,
        style=request.style,
        artist=request.artist,
        website=request.website,
        resolution=request.resolution,
        additional_details=request.additional_details,
        color=request.color,
        lightning=request.lightning,
    )
    # NOTE: an earlier ``image.resize((256, 256))`` call here was dropped.
    # PIL's resize returns a new image and the result was discarded, so the
    # model has always received the image at its decoded size.
    image = decode_image(request.image)

    scale = request.scale
    if prompt:
        prompt = f"best quality, high quality, {prompt}"
        crud.create_prompt(db=db, user_id=current_user.user_id, prompt=prompt)
    else:
        # Without a prompt the scale is forced to 1.0.
        scale = 1.0

    images = model.state.iv_model.generate(
        pil_image=image,
        num_samples=request.num_samples,
        num_inference_steps=request.num_inference_steps,
        seed=request.seed,
        prompt=prompt,
        negative_prompt=request.negative_prompt,
        scale=scale,
        guidance_scale=request.guidance_scale,
    )

    return [encode_image(variation) for variation in images]