3v324v23 committed on
Commit
ecc9585
1 Parent(s): 04e062b
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. README.md +6 -5
  2. app.py +311 -0
  3. css/0.png +0 -0
  4. css/style.css +59 -0
  5. image_gallery/00.png +0 -0
  6. image_gallery/01.png +0 -0
  7. image_gallery/02.png +0 -0
  8. image_gallery/a001.jpg +0 -0
  9. image_gallery/a002.jpg +0 -0
  10. image_gallery/a003.jpg +0 -0
  11. image_gallery/a004.jpg +0 -0
  12. image_gallery/a005.jpg +0 -0
  13. image_gallery/a006.jpg +0 -0
  14. image_gallery/a007.jpg +0 -0
  15. image_gallery/a009.jpg +0 -0
  16. image_gallery/bg_001.jpg +0 -0
  17. image_gallery/bg_002.jpg +0 -0
  18. image_gallery/bg_003.jpg +0 -0
  19. image_gallery/bg_004.jpg +0 -0
  20. image_gallery/bg_005.jpg +0 -0
  21. image_gallery/bg_006.jpg +0 -0
  22. image_gallery/bg_007.jpg +0 -0
  23. image_gallery/bg_008.jpg +0 -0
  24. image_gallery/bg_009.jpg +0 -0
  25. image_gallery/bg_010.jpg +0 -0
  26. image_gallery/bg_012.jpg +0 -0
  27. imgs/000.jpg +0 -0
  28. imgs/001.jpg +0 -0
  29. imgs/002.png +0 -0
  30. imgs/002_bg.png +0 -0
  31. imgs/003.png +0 -0
  32. imgs/003_bg.jpg +0 -0
  33. imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png +0 -0
  34. imgs/bg_gen/base_imgs/IMG_2941.png +0 -0
  35. imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png +0 -0
  36. imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg +0 -0
  37. models/DOWNLOAD_MODEL_HERE.txt +2 -0
  38. models/sam_vit_h_4b8939.pth +3 -0
  39. requirements.txt +23 -0
  40. sdxl.txt +10 -0
  41. src/__init__.py +0 -0
  42. src/__pycache__/__init__.cpython-38.pyc +0 -0
  43. src/__pycache__/background_generation.cpython-38.pyc +0 -0
  44. src/__pycache__/log.cpython-38.pyc +0 -0
  45. src/__pycache__/person_detect.cpython-38.pyc +0 -0
  46. src/__pycache__/util.cpython-38.pyc +0 -0
  47. src/__pycache__/virtualmodel.cpython-38.pyc +0 -0
  48. src/background_generation.py +76 -0
  49. src/log.py +18 -0
  50. src/person_detect.py +39 -0
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: SAM SDXL Inpainting
- emoji: 🦀
- colorFrom: red
- colorTo: gray
+ title: ReplaceAnything Using SAM + SDXL Inpainting
+ emoji: 📚
+ colorFrom: yellow
+ colorTo: blue
  sdk: gradio
- sdk_version: 4.14.0
+ sdk_version: 3.50.2
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---
 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,311 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ # @Time : 2023-06-01
+ # @Author : ashui(Binghui Chen)
+ from sympy import im  # note: unused stray import
+ import time
+ import cv2
+ import gradio as gr
+ import numpy as np
+ import random
+ import math
+ import uuid
+ import torch
+ from torch import autocast
+
+ from src.util import resize_image, upload_np_2_oss
+ from diffusers import AutoPipelineForInpainting, UNet2DConditionModel
+ import diffusers
+ import sys, os
+
+ from PIL import Image, ImageFilter, ImageOps, ImageDraw
+
+ from segment_anything import SamPredictor, sam_model_registry
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)
+
+ mobile_sam = sam_model_registry['vit_h'](checkpoint='models/sam_vit_h_4b8939.pth').to(device)  # respect CPU fallback
+ mobile_sam.eval()
+ mobile_predictor = SamPredictor(mobile_sam)
+ colors = [(255, 0, 0), (0, 255, 0)]
+ markers = [1, 5]
+
+ # - - - - - examples - - - - - #
+ # input image path, prompt text, background image path, example index, []
+ image_examples = [
+     ["imgs/000.jpg", "A young woman in short sleeves shows off a mobile phone", None, 0, []],
+     ["imgs/001.jpg", "A young woman wears short sleeves, her hand is holding a bottle.", None, 1, []],
+     ["imgs/003.png", "A woman is wearing a black suit against a blue background", "imgs/003_bg.jpg", 2, []],
+     ["imgs/002.png", "A young woman poses in a dress, she stands in front of a blue background", "imgs/002_bg.png", 3, []],
+     ["imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png", "water splash", None, 4, []],
+     ["imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png", "", "imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg", 5, []],
+     ["imgs/bg_gen/base_imgs/IMG_2941.png", "On the desert floor", None, 6, []],
+     ["imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png", "White ground, white background, light coming in, Canon", None, 7, []],
+ ]
+
+ img = "image_gallery/"
+ files = os.listdir(img)
+ files = sorted(files)
+ showcases = []
+ for idx, name in enumerate(files):
+     temp = os.path.join(os.path.dirname(__file__), img, name)
+     showcases.append(temp)
+
+ def process(original_image, original_mask, input_mask, selected_points, prompt, negative_prompt, guidance_scale, steps, strength, scheduler):
+     # validate inputs before touching .shape (original_image may be None or an example index)
+     if original_image is None:
+         raise gr.Error('Please upload the input image')
+     if (original_mask is None or len(selected_points) == 0) and input_mask is None:
+         raise gr.Error("Please click the region you want to keep unchanged, or upload a black-and-white mask image in which white marks the region to be retained.")
+
+     # load example image (examples pass an integer index instead of an array)
+     if isinstance(original_image, int):
+         image_name = image_examples[original_image][0]
+         original_image = cv2.imread(image_name)
+         original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
+
+     # cap portrait images at 1000 px tall
+     if original_image.shape[0] > original_image.shape[1]:
+         original_image = cv2.resize(original_image, (int(original_image.shape[1]*1000/original_image.shape[0]), 1000))
+     if original_mask is not None and original_mask.shape[0] > original_mask.shape[1]:
+         original_mask = cv2.resize(original_mask, (int(original_mask.shape[1]*1000/original_mask.shape[0]), 1000))
+
+     if input_mask is not None:
+         H, W = original_image.shape[:2]
+         original_mask = cv2.resize(input_mask, (W, H))
+     else:
+         original_mask = np.clip(255 - original_mask, 0, 255).astype(np.uint8)
+
+     request_id = str(uuid.uuid4())
+     # input_image_url = upload_np_2_oss(original_image, request_id+".png")
+     # input_mask_url = upload_np_2_oss(original_mask, request_id+"_mask.png")
+     # source_background_url = "" if source_background is None else upload_np_2_oss(source_background, request_id+"_bg.png")
+     if negative_prompt == "":
+         negative_prompt = None
+     # dropdown values look like "DPMSolverMultistepScheduler-Karras-SDE": the first token
+     # is the diffusers class name, the suffixes select scheduler config flags
+     scheduler_class_name = scheduler.split("-")[0]
+
+     add_kwargs = {}
+     if len(scheduler.split("-")) > 1:
+         add_kwargs["use_karras_sigmas"] = True
+     if len(scheduler.split("-")) > 2:
+         add_kwargs["algorithm_type"] = "sde-dpmsolver++"
+
+     scheduler = getattr(diffusers, scheduler_class_name)
+     pipe.scheduler = scheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler", **add_kwargs)
+
+     # Image.fromarray(original_mask).save("original_mask.png")
+     init_image = Image.fromarray(original_image).convert("RGB")
+     mask = Image.fromarray(original_mask).convert("RGB")
+     output = pipe(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+     # person detect: [[x1,y1,x2,y2,score],]
+     # det_res = call_person_detect(input_image_url)
+
+     res = []
+     # if len(det_res)>0:
+     #     if len(prompt)==0:
+     #         raise gr.Error('Please input the prompt')
+     #     # res = call_virtualmodel(input_image_url, input_mask_url, source_background_url, prompt, face_prompt)
+     # else:
+     #     ###
+     #     if len(prompt)==0:
+     #         prompt=None
+     #     ref_image_url=None if source_background_url =='' else source_background_url
+     #     original_mask=original_mask[:,:,:1]
+     #     base_image=np.concatenate([original_image, original_mask],axis=2)
+     #     base_image_url=upload_np_2_oss(base_image, request_id+"_base.png")
+     #     res=call_bg_genration(base_image_url,ref_image_url,prompt,ref_prompt_weight=0.5)
+     # Image.fromarray(input_mask).save("input_mask.png")
+     res = output.images[0]
+     res = res.convert("RGB")
+     # resize the output image back to the original image size
+     res = res.resize((original_image.shape[1], original_image.shape[0]), Image.LANCZOS)
+     return [res], request_id, True
+
+ block = gr.Blocks(
+     css="css/style.css",
+     theme=gr.themes.Soft(
+         radius_size=gr.themes.sizes.radius_none,
+         text_size=gr.themes.sizes.text_md
+     )
+ ).queue(concurrency_count=2)
+ with block:
+     with gr.Row():
+         with gr.Column():
+             gr.HTML(f"""
+             </br>
+             <div class="baselayout" style="text-shadow: white 0.01rem 0.01rem 0.4rem; position:fixed; z-index: 9999; top:0; left:0; right:0; background-size:100% 100%">
+                 <h1 style="text-align:center; color:Black; font-size:3rem; position: relative;"> SAM + SDXL Inpainting </h1>
+             </div>
+             </br>
+             </br>
+             <div style="text-align: center;">
+                 <h1>ReplaceAnything using SAM + SDXL Inpainting as you want: ultra-high quality content replacement</h1>
+             </div>
+             """)
+
+     with gr.Tabs(elem_classes=["Tab"]):
+         with gr.TabItem("Image Create"):
+             with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
+                 with gr.Row(equal_height=True):
+                     gr.Markdown("""
+                     - ⭐️ <b>step 1:</b> Upload an image or select one from the examples
+                     - ⭐️ <b>step 2:</b> Click the input image to select the object to be retained (or upload a black-and-white mask image in which white marks the region you want to keep unchanged)
+                     - ⭐️ <b>step 3:</b> Enter a prompt or a reference image (highly recommended) for generating the new content
+                     - ⭐️ <b>step 4:</b> Click the Run button
+                     """)
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Column(elem_id="Input"):
+                         with gr.Row():
+                             with gr.Tabs(elem_classes=["feedback"]):
+                                 with gr.TabItem("Input Image"):
+                                     input_image = gr.Image(type="numpy", label="input", scale=2)
+                         original_image = gr.State(value=None, label="index")
+                         original_mask = gr.State(value=None)
+                         selected_points = gr.State([], label="click points")
+                         with gr.Row(elem_id="Seg"):
+                             radio = gr.Radio(['foreground', 'background'], label='Click to seg: ', value='foreground', scale=2)
+                             undo_button = gr.Button('Undo seg', elem_id="btnSEG", scale=1)
+                     input_mask = gr.Image(type="numpy", label="Mask Image")
+                     prompt = gr.Textbox(label="Prompt", placeholder="Please input your prompt", value='', lines=1)
+                     negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Please input your negative prompt", value='hand,blur,face,bad', lines=1)
+                     guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale")
+                     steps = gr.Number(value=20, minimum=10, maximum=30, step=1, label="steps")
+                     strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength")
+                     with gr.Row(mobile_collapse=False, equal_height=True):
+                         schedulers = ["DEISMultistepScheduler", "HeunDiscreteScheduler", "EulerDiscreteScheduler", "DPMSolverMultistepScheduler", "DPMSolverMultistepScheduler-Karras", "DPMSolverMultistepScheduler-Karras-SDE"]
+                         scheduler = gr.Dropdown(label="Schedulers", choices=schedulers, value="EulerDiscreteScheduler")
+
+                     run_button = gr.Button("Run", elem_id="btn")
+
+                 with gr.Column():
+                     with gr.Tabs(elem_classes=["feedback"]):
+                         with gr.TabItem("Outputs"):
+                             result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True)
+                             # recommend = gr.Button("Recommend results to Image Gallery", elem_id="recBut")
+                             request_id = gr.State(value="")
+                             gallery_flag = gr.State(value=False)
+
+     # once the user uploads an image, the original image is stored in `original_image`
+     def store_img(img):
+         # image upload is too slow
+         # if min(img.shape[0], img.shape[1]) > 896:
+         #     img = resize_image(img, 896)
+         # if max(img.shape[0], img.shape[1])*1.0/min(img.shape[0], img.shape[1]) > 2.0:
+         #     raise gr.Error('image aspect ratio cannot be larger than 2.0')
+         return img, img, [], None  # when a new image is uploaded, `selected_points` is reset and the mask is cleared
+
+     input_image.upload(
+         store_img,
+         [input_image],
+         [input_image, original_image, selected_points, original_mask]  # four outputs to match the four returned values
+     )
+
+     # the user clicks the image to add points; show the points on the image
+     def segmentation(img, sel_pix):
+         print("segmentation")
+         # online show seg mask
+         points = []
+         labels = []
+         for p, l in sel_pix:
+             points.append(p)
+             labels.append(l)
+         mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
+         with torch.no_grad():
+             with autocast("cuda"):
+                 masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
+
+         output_mask = np.ones((masks.shape[1], masks.shape[2], 3)) * 255
+         for i in range(3):
+             output_mask[masks[0] == True, i] = 0.0
+
+         mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
+         color_mask = np.random.random((1, 3)).tolist()[0]
+         for i in range(3):
+             mask_all[masks[0] == True, i] = color_mask[i]
+         masked_img = img / 255 * 0.3 + mask_all * 0.7
+         masked_img = masked_img * 255
+         # draw points
+         for point, label in sel_pix:
+             cv2.drawMarker(masked_img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
+         return masked_img, output_mask
+
+     def get_point(img, sel_pix, point_type, evt: gr.SelectData):
+         if point_type == 'foreground':
+             sel_pix.append((evt.index, 1))  # append the foreground point
+         elif point_type == 'background':
+             sel_pix.append((evt.index, 0))  # append the background point
+         else:
+             sel_pix.append((evt.index, 1))  # default to foreground point
+
+         if isinstance(img, int):
+             image_name = image_examples[img][0]
+             img = cv2.imread(image_name)
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+         # online show seg mask
+         if img.shape[0] > img.shape[1]:
+             img = cv2.resize(img, (int(img.shape[1]*1000/img.shape[0]), 1000))
+         masked_img, output_mask = segmentation(img, sel_pix)
+
+         return masked_img.astype(np.uint8), output_mask
+
+     input_image.select(
+         get_point,
+         [original_image, selected_points, radio],
+         [input_image, original_mask],
+     )
+
+     # undo the last selected point
+     def undo_points(orig_img, sel_pix):
+         output_mask = None
+         if len(sel_pix) != 0:
+             if isinstance(orig_img, int):  # if orig_img is an int, the image was selected from the examples
+                 temp = cv2.imread(image_examples[orig_img][0])
+                 temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+             else:
+                 temp = orig_img.copy()
+             sel_pix.pop()
+             # online show seg mask
+             if len(sel_pix) != 0:
+                 temp, output_mask = segmentation(temp, sel_pix)
+             return temp.astype(np.uint8), output_mask
+         else:
+             raise gr.Error("Nothing to Undo")
+
+     undo_button.click(
+         undo_points,
+         [original_image, selected_points],
+         [input_image, original_mask]
+     )
+
+     def upload_to_img_gallery(img, res, re_id, flag):
+         if flag:
+             gr.Info("Image uploading")
+             if isinstance(img, int):
+                 image_name = image_examples[img][0]
+                 img = cv2.imread(image_name)
+                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+             _ = upload_np_2_oss(img, name=re_id+"_ori.jpg", gallery=True)
+             for idx, r in enumerate(res):
+                 r = cv2.imread(r['name'])
+                 r = cv2.cvtColor(r, cv2.COLOR_BGR2RGB)
+                 _ = upload_np_2_oss(r, name=re_id+f"_res_{idx}.jpg", gallery=True)
+             flag = False
+             gr.Info("Images have been uploaded and are under review")
+         else:
+             gr.Info("Nothing to do")
+         return flag
+
+     # recommend.click(
+     #     upload_to_img_gallery,
+     #     [original_image, result_gallery, request_id, gallery_flag],
+     #     [gallery_flag]
+     # )
+     # ips = [input_image, original_image, original_mask, input_mask, selected_points, prompt, negative_prompt, guidance_scale, steps, strength, scheduler]
+     ips = [original_image, original_mask, input_mask, selected_points, prompt, negative_prompt, guidance_scale, steps, strength, scheduler]
+     run_button.click(fn=process, inputs=ips, outputs=[result_gallery, request_id, gallery_flag])
+
+
+ block.launch(share=True)
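For readers who want the core technique without the Gradio UI, here is a minimal sketch of the same SAM-then-inpaint flow. It is not part of this commit: it assumes the same checkpoint path and model IDs as app.py, the click coordinates and prompt are placeholders, and the mask convention keeps the clicked object while repainting the background.

```python
# Minimal sketch of the SAM + SDXL inpainting flow used by app.py (no UI).
import cv2
import numpy as np
import torch
from PIL import Image
from diffusers import AutoPipelineForInpainting
from segment_anything import SamPredictor, sam_model_registry

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = AutoPipelineForInpainting.from_pretrained(
    "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
    torch_dtype=torch.float16, variant="fp16").to(device)
sam = sam_model_registry["vit_h"](checkpoint="models/sam_vit_h_4b8939.pth").to(device)
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("imgs/000.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)
# one foreground click on the object to keep (placeholder coordinates)
masks, _, _ = predictor.predict(point_coords=np.array([[400, 300]]),
                                point_labels=np.array([1]),
                                multimask_output=False)
# SDXL inpainting repaints WHITE mask pixels, so make the kept object black
mask = np.where(masks[0], 0, 255).astype(np.uint8)
result = pipe(prompt="a product photo on a marble table",
              image=Image.fromarray(image).convert("RGB"),
              mask_image=Image.fromarray(mask).convert("RGB"),
              guidance_scale=7.5, num_inference_steps=20, strength=0.99).images[0]
result.save("out.png")
```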
css/0.png ADDED
css/style.css ADDED
@@ -0,0 +1,59 @@
+
+ .baselayout {
+     background: url('https://img.alicdn.com/imgextra/i1/O1CN016hd0V91ilWY5Xr24B_!!6000000004453-2-tps-2882-256.png') no-repeat;
+ }
+ #btn {
+     background-color: #336699;
+     color: white;
+ }
+ #recBut {
+     background-color: #bb5252;
+     color: white;
+     width: 30%;
+     margin: auto;
+ }
+ #btnSEG {
+     background-color: #D5F3F4;
+     color: black;
+ }
+ #btnCHAT {
+     background-color: #B6DBF2;
+     color: black;
+ }
+ #accordion {
+     background-color: transparent;
+ }
+ #accordion1 {
+     background-color: #ecedee;
+ }
+ .feedback button.selected {
+     background-color: #6699CC;
+     color: white !important;
+ }
+ .feedback1 button.selected {
+     background-color: #839ab2;
+     color: white !important;
+ }
+ .Tab button.selected {
+     color: red;
+     font-weight: bold;
+ }
+ #Image {
+     width: 80%;
+     margin: auto;
+ }
+ #ShowCase {
+     width: 30%;
+     flex: none !important;
+ }
+
+ #Input {
+     border-style: solid;
+     border-width: 1px;
+     border-color: #000000;
+ }
+ #Seg {
+     min-width: min(100px, 100%) !important;
+     width: 100%;
+     margin: auto;
+ }
image_gallery/00.png ADDED
image_gallery/01.png ADDED
image_gallery/02.png ADDED
image_gallery/a001.jpg ADDED
image_gallery/a002.jpg ADDED
image_gallery/a003.jpg ADDED
image_gallery/a004.jpg ADDED
image_gallery/a005.jpg ADDED
image_gallery/a006.jpg ADDED
image_gallery/a007.jpg ADDED
image_gallery/a009.jpg ADDED
image_gallery/bg_001.jpg ADDED
image_gallery/bg_002.jpg ADDED
image_gallery/bg_003.jpg ADDED
image_gallery/bg_004.jpg ADDED
image_gallery/bg_005.jpg ADDED
image_gallery/bg_006.jpg ADDED
image_gallery/bg_007.jpg ADDED
image_gallery/bg_008.jpg ADDED
image_gallery/bg_009.jpg ADDED
image_gallery/bg_010.jpg ADDED
image_gallery/bg_012.jpg ADDED
imgs/000.jpg ADDED
imgs/001.jpg ADDED
imgs/002.png ADDED
imgs/002_bg.png ADDED
imgs/003.png ADDED
imgs/003_bg.jpg ADDED
imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png ADDED
imgs/bg_gen/base_imgs/IMG_2941.png ADDED
imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png ADDED
imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg ADDED
models/DOWNLOAD_MODEL_HERE.txt ADDED
@@ -0,0 +1,2 @@
+ Model link:
+ https://vision-poster.oss-cn-shanghai.aliyuncs.com/ashui/sam_vit_h_4b8939.pth?OSSAccessKeyId=LTAI5tSPYbksBzcmooNHCYif&Expires=3599001703148669&Signature=TYznO77DKFjGNn92SnR9RbucOlU%3D
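As a convenience, a minimal sketch of fetching the checkpoint programmatically. This is not part of the commit; it assumes the official segment-anything release URL (the signed OSS link above serves the same file) and a models/ directory next to the script.

```python
# Hypothetical helper: download the SAM ViT-H checkpoint into models/
# before launching the app.
import os
import urllib.request

CKPT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
CKPT_PATH = "models/sam_vit_h_4b8939.pth"

os.makedirs("models", exist_ok=True)
if not os.path.exists(CKPT_PATH):
    urllib.request.urlretrieve(CKPT_URL, CKPT_PATH)  # ~2.4 GB download
```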
models/sam_vit_h_4b8939.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+ size 2564550879
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ dashscope
+ sympy
+ Pillow==9.5.0
+ gradio==3.50.0
+ opencv-python
+ omegaconf
+ sentencepiece
+ easydict
+ scikit-image
+ git+https://github.com/facebookresearch/segment-anything.git
+ torch
+ torchvision
+ oss2==2.17.0
+ --extra-index-url https://download.pytorch.org/whl/cu118
+ torch
+ git+https://github.com/huggingface/diffusers.git
+ transformers
+ accelerate
+ ftfy
+ numpy
+ matplotlib
+ uuid
+ opencv-python
sdxl.txt ADDED
@@ -0,0 +1,10 @@
+ --extra-index-url https://download.pytorch.org/whl/cu118
+ torch
+ git+https://github.com/huggingface/diffusers.git
+ transformers
+ accelerate
+ ftfy
+ numpy
+ matplotlib
+ uuid
+ opencv-python
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (153 Bytes)
src/__pycache__/background_generation.cpython-38.pyc ADDED
Binary file (2.52 kB)
src/__pycache__/log.cpython-38.pyc ADDED
Binary file (741 Bytes)
src/__pycache__/person_detect.cpython-38.pyc ADDED
Binary file (1.2 kB)
src/__pycache__/util.cpython-38.pyc ADDED
Binary file (4.86 kB)
src/__pycache__/virtualmodel.cpython-38.pyc ADDED
Binary file (2.44 kB)
src/background_generation.py ADDED
@@ -0,0 +1,76 @@
+ import os
+ import numpy
+ from PIL import Image
+ import requests
+ import urllib.request
+ from http import HTTPStatus
+ from datetime import datetime
+ import json
+ from .log import logger
+ import time
+ import gradio as gr
+ from .util import download_images
+
+ def call_bg_genration(base_image, ref_img, prompt, ref_prompt_weight=0.5):
+     API_KEY = os.getenv("API_KEY_BG_GENERATION")
+     BATCH_SIZE = 4
+     headers = {
+         "Content-Type": "application/json",
+         "Accept": "application/json",
+         "Authorization": f"Bearer {API_KEY}",
+         "X-DashScope-Async": "enable",
+     }
+     data = {
+         "model": "wanx-background-generation-v2",
+         "input": {
+             "base_image_url": base_image,
+             "ref_image_url": ref_img,
+             "ref_prompt": prompt,
+         },
+         "parameters": {
+             "ref_prompt_weight": ref_prompt_weight,
+             "n": BATCH_SIZE
+         }
+     }
+     url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/background-generation/generation'
+     res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
+
+     respose_code = res_.status_code
+     if 200 == respose_code:
+         res = json.loads(res_.content.decode())
+         request_id = res['request_id']
+         task_id = res['output']['task_id']
+         logger.info(f"task_id: {task_id}: Create Background Generation request success. Params: {data}")
+
+         # poll the asynchronous task until it finishes
+         is_running = True
+         while is_running:
+             url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
+             res_ = requests.post(url_query, headers=headers)
+             respose_code = res_.status_code
+             if 200 == respose_code:
+                 res = json.loads(res_.content.decode())
+                 if "SUCCEEDED" == res['output']['task_status']:
+                     logger.info(f"task_id: {task_id}: Background generation task query success.")
+                     results = res['output']['results']
+                     img_urls = [x['url'] for x in results]
+                     logger.info(f"task_id: {task_id}: {res}")
+                     break
+                 elif "FAILED" != res['output']['task_status']:
+                     logger.debug(f"task_id: {task_id}: query result...")
+                     time.sleep(1)
+                 else:
+                     raise gr.Error('Fail to get results from Background Generation task.')
+             else:
+                 logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
+                 raise gr.Error("Fail to query task result.")
+
+         logger.info(f"task_id: {task_id}: download generated images.")
+         img_data = download_images(img_urls, BATCH_SIZE)
+         logger.info(f"task_id: {task_id}: Generate done.")
+         return img_data
+     else:
+         logger.error(f'Fail to create Background Generation task: {res_.content}')
+         raise gr.Error("Fail to create Background Generation task.")
+
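For context, `call_bg_genration` follows DashScope's asynchronous pattern: one POST with `X-DashScope-Async: enable` creates a task, and the task endpoint is then polled until the status leaves the queue, as the code above does. A hypothetical call might look like the sketch below; the URL is a placeholder, `API_KEY_BG_GENERATION` must be set, and per the commented-out call in app.py the base image is an RGBA upload whose fourth channel is the keep-mask.

```python
# Hypothetical usage sketch; not part of this commit.
from src.background_generation import call_bg_genration

images = call_bg_genration(
    base_image="https://example.com/base_with_alpha.png",  # RGBA: alpha marks the kept subject
    ref_img=None,                                          # or a reference-style image URL
    prompt="water splash",
    ref_prompt_weight=0.5)
```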
src/log.py ADDED
@@ -0,0 +1,18 @@
+ import logging
+ from logging.handlers import RotatingFileHandler
+ import os
+
+ log_file_name = "workdir/log_replaceAnything.log"
+ os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
+
+ format = '[%(levelname)s] %(asctime)s "%(filename)s", line %(lineno)d, %(message)s'
+ logging.basicConfig(
+     format=format,
+     datefmt="%Y-%m-%d %H:%M:%S",
+     level=logging.INFO)
+ logger = logging.getLogger(name="WordArt_Studio")
+
+ fh = RotatingFileHandler(log_file_name, maxBytes=20000000, backupCount=3)
+ formatter = logging.Formatter(format, datefmt="%Y-%m-%d %H:%M:%S")
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
src/person_detect.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ import numpy
+ from PIL import Image
+ import requests
+ import urllib.request
+ from http import HTTPStatus
+ from datetime import datetime
+ import json
+ from .log import logger
+ import time
+ import gradio as gr
+ from .util import download_images
+
+ API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
+
+ def call_person_detect(input_image_url):
+     headers = {
+         "Content-Type": "application/json",
+         "Accept": "application/json",
+         "Authorization": f"Bearer {API_KEY}",
+         "X-DashScope-DataInspection": "enable",
+     }
+     data = {
+         "model": "body-detection",
+         "input": {
+             "image_url": input_image_url,
+         },
+         "parameters": {
+             "score_threshold": 0.6,
+         }
+     }
+     url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/vision/bodydetection/detect'
+     res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
+
+     res = json.loads(res_.content.decode())
+     request_id = res['request_id']
+     results = res['output']['results']
+     return results
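Unlike the background-generation call, `call_person_detect` is synchronous (no task polling); per the commented-out call in app.py, each result is a box of the form [x1, y1, x2, y2, score]. A hypothetical usage sketch, with a placeholder URL and `API_KEY_VIRTUALMODEL` assumed to be set:

```python
# Hypothetical usage sketch; not part of this commit.
from src.person_detect import call_person_detect

boxes = call_person_detect("https://example.com/photo.jpg")
for x1, y1, x2, y2, score in boxes:
    print(f"person at ({x1}, {y1})-({x2}, {y2}) with confidence {score:.2f}")
```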