Spaces:
Running
on
Zero
Running
on
Zero
remove mask generation
Browse files
app.py
CHANGED
@@ -39,9 +39,12 @@ dtype = torch.bfloat16
|
|
39 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
40 |
base_model = "black-forest-labs/FLUX.1-dev"
|
41 |
|
|
|
|
|
|
|
42 |
|
43 |
-
FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
|
44 |
-
SAM_IMAGE_MODEL = load_sam_image_model(device=device)
|
45 |
|
46 |
|
47 |
class calculateDuration:
|
@@ -147,9 +150,7 @@ def run_flux(
|
|
147 |
) -> Image.Image:
|
148 |
print("Running FLUX...")
|
149 |
|
150 |
-
|
151 |
-
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
|
152 |
-
pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
153 |
|
154 |
with calculateDuration("load lora"):
|
155 |
print("start to load lora", lora_path, lora_weights)
|
@@ -178,62 +179,10 @@ def run_flux(
|
|
178 |
|
179 |
return genearte_image
|
180 |
|
181 |
-
@spaces.GPU(duration=10)
|
182 |
-
def genearte_mask(image_input: Image.Image, masking_prompt_text: str) -> Image.Image:
|
183 |
-
# generate mask by florence & sam
|
184 |
-
print("Generating mask...")
|
185 |
-
task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
|
186 |
-
|
187 |
-
with calculateDuration("FLORENCE"):
|
188 |
-
print(task_prompt, masking_prompt_text)
|
189 |
-
_, result = run_florence_inference(
|
190 |
-
model=FLORENCE_MODEL,
|
191 |
-
processor=FLORENCE_PROCESSOR,
|
192 |
-
device=device,
|
193 |
-
image=image_input,
|
194 |
-
task=task_prompt,
|
195 |
-
text=masking_prompt_text
|
196 |
-
)
|
197 |
-
|
198 |
-
with calculateDuration("sv.Detections"):
|
199 |
-
# start to dectect
|
200 |
-
detections = sv.Detections.from_lmm(
|
201 |
-
lmm=sv.LMM.FLORENCE_2,
|
202 |
-
result=result,
|
203 |
-
resolution_wh=image_input.size
|
204 |
-
)
|
205 |
-
|
206 |
-
images = []
|
207 |
-
|
208 |
-
with calculateDuration("generate segmenet mask"):
|
209 |
-
# using sam generate segments images
|
210 |
-
detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
|
211 |
-
if len(detections) == 0:
|
212 |
-
gr.Info("No objects detected.")
|
213 |
-
return None
|
214 |
-
print("mask generated:", len(detections.mask))
|
215 |
-
kernel_size = dilate
|
216 |
-
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
217 |
-
|
218 |
-
for i in range(len(detections.mask)):
|
219 |
-
mask = detections.mask[i].astype(np.uint8) * 255
|
220 |
-
images.append(mask)
|
221 |
-
|
222 |
-
# merge mark into on image
|
223 |
-
merged_mask = np.zeros_like(images[0], dtype=np.uint8)
|
224 |
-
for mask in images:
|
225 |
-
merged_mask = cv2.bitwise_or(merged_mask, mask)
|
226 |
-
|
227 |
-
images = [merged_mask]
|
228 |
-
|
229 |
-
return images[0]
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
def process(
|
234 |
image_url: str,
|
|
|
235 |
inpainting_prompt_text: str,
|
236 |
-
masking_prompt_text: str,
|
237 |
mask_inflation_slider: int,
|
238 |
mask_blur_slider: int,
|
239 |
seed_slicer: int,
|
@@ -260,26 +209,16 @@ def process(
|
|
260 |
result["message"] = "invalid inpainting prompt"
|
261 |
return json.dumps(result)
|
262 |
|
263 |
-
if not masking_prompt_text:
|
264 |
-
gr.Info("Please enter masking_prompt_text.")
|
265 |
-
result["message"] = "invalid masking prompt"
|
266 |
-
return json.dumps(result)
|
267 |
|
268 |
with calculateDuration("load image"):
|
269 |
image = load_image(image_url)
|
|
|
270 |
|
271 |
-
|
272 |
-
|
273 |
-
if not image:
|
274 |
-
gr.Info("Please upload an image.")
|
275 |
result["message"] = "can not load image"
|
276 |
return json.dumps(result)
|
277 |
|
278 |
-
if is_mask_empty(mask):
|
279 |
-
gr.Info("Please draw a mask or enter a masking prompt.")
|
280 |
-
result["message"] = "can not generate mask"
|
281 |
-
return json.dumps(result)
|
282 |
-
|
283 |
# generate
|
284 |
width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
|
285 |
image = image.resize((width, height), Image.LANCZOS)
|
@@ -321,11 +260,11 @@ with gr.Blocks() as demo:
|
|
321 |
container=False,
|
322 |
)
|
323 |
|
324 |
-
|
325 |
-
label="
|
326 |
show_label=False,
|
327 |
max_lines=1,
|
328 |
-
placeholder="Enter
|
329 |
container=False,
|
330 |
)
|
331 |
|
@@ -439,8 +378,8 @@ with gr.Blocks() as demo:
|
|
439 |
fn=process,
|
440 |
inputs=[
|
441 |
image_url,
|
|
|
442 |
inpainting_prompt_text_component,
|
443 |
-
masking_prompt_text_component,
|
444 |
mask_inflation_slider_component,
|
445 |
mask_blur_slider_component,
|
446 |
seed_slicer_component,
|
|
|
39 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
40 |
base_model = "black-forest-labs/FLUX.1-dev"
|
41 |
|
42 |
+
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
43 |
+
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
|
44 |
+
pipe = FluxInpaintPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
45 |
|
46 |
+
# FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=device)
|
47 |
+
# SAM_IMAGE_MODEL = load_sam_image_model(device=device)
|
48 |
|
49 |
|
50 |
class calculateDuration:
|
|
|
150 |
) -> Image.Image:
|
151 |
print("Running FLUX...")
|
152 |
|
153 |
+
|
|
|
|
|
154 |
|
155 |
with calculateDuration("load lora"):
|
156 |
print("start to load lora", lora_path, lora_weights)
|
|
|
179 |
|
180 |
return genearte_image
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
def process(
|
183 |
image_url: str,
|
184 |
+
mask_url: str,
|
185 |
inpainting_prompt_text: str,
|
|
|
186 |
mask_inflation_slider: int,
|
187 |
mask_blur_slider: int,
|
188 |
seed_slicer: int,
|
|
|
209 |
result["message"] = "invalid inpainting prompt"
|
210 |
return json.dumps(result)
|
211 |
|
|
|
|
|
|
|
|
|
212 |
|
213 |
with calculateDuration("load image"):
|
214 |
image = load_image(image_url)
|
215 |
+
mask = load_image(mask_url)
|
216 |
|
217 |
+
if not image or not mask:
|
218 |
+
gr.Info("Please upload an image & mask by url.")
|
|
|
|
|
219 |
result["message"] = "can not load image"
|
220 |
return json.dumps(result)
|
221 |
|
|
|
|
|
|
|
|
|
|
|
222 |
# generate
|
223 |
width, height = calculate_image_dimensions_for_flux(original_resolution_wh=image.size)
|
224 |
image = image.resize((width, height), Image.LANCZOS)
|
|
|
260 |
container=False,
|
261 |
)
|
262 |
|
263 |
+
mask_url = gr.Text(
|
264 |
+
label="image url of masking",
|
265 |
show_label=False,
|
266 |
max_lines=1,
|
267 |
+
placeholder="Enter url of masking",
|
268 |
container=False,
|
269 |
)
|
270 |
|
|
|
378 |
fn=process,
|
379 |
inputs=[
|
380 |
image_url,
|
381 |
+
mask_url,
|
382 |
inpainting_prompt_text_component,
|
|
|
383 |
mask_inflation_slider_component,
|
384 |
mask_blur_slider_component,
|
385 |
seed_slicer_component,
|