"""Gradio demo: text-guided object replacement in an image.

Pipeline: GroundingDINO locates the object named by a text caption, Segment
Anything (SAM) turns the detected box into a pixel mask, and Stable Diffusion
inpainting repaints the masked region according to a prompt.
"""

# Gradio and other necessary imports
import os
import subprocess

import gradio as gr

# The setup script must run before the imports below; presumably it fetches
# GroundingDINO and the model weights referenced further down (verify against
# setup.sh). The exit status is intentionally not checked, as in the original.
subprocess.run(["bash", "setup.sh"])

from segment_anything import SamPredictor, sam_model_registry  # noqa: E402
from diffusers import StableDiffusionInpaintPipeline  # noqa: E402
from GroundingDINO.groundingdino.util.inference import (  # noqa: E402
    load_model,
    load_image,
    predict,
    annotate,
)
from GroundingDINO.groundingdino.util import box_ops  # noqa: E402
from PIL import Image  # noqa: E402
import torch  # noqa: E402
import numpy as np  # noqa: E402

device = torch.device("cpu")

# ----SAM
print("path", os.getcwd())
model_type = "vit_h"
predictor = SamPredictor(
    sam_model_registry[model_type](
        checkpoint="./GroundingDINO/weights/sam_vit_h_4b8939.pth"
    ).to(device)
)

# ------Stable Diffusion
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
).to(device)

# ----Grounding DINO
groundingdino_model = load_model(
    "./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "./GroundingDINO/weights/groundingdino_swint_ogc.pth",
)

# Default detection thresholds (names keep the original spelling because they
# are module-level identifiers other code may reference).
BOX_TRESHOLD = 0.3
TEXT_TRESHOLD = 0.25

# Size both the source image and the mask are resized to before inpainting.
INPAINT_SIZE = (512, 512)


def show_mask(mask, image, random_color=True):
    """Overlay a semi-transparent coloured mask on an image.

    Args:
        mask: Mask whose last two dimensions are (height, width). ``.cpu()``
            is called on the coloured result, so this is expected to be a
            torch tensor.
        image: Array accepted by ``PIL.Image.fromarray``; it is converted to
            RGBA before compositing.
        random_color: If true, use a random RGB colour with alpha 0.8;
            otherwise use a fixed blue with alpha 0.6.

    Returns:
        The alpha-composited RGBA image as a numpy array.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    h, w = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) RGBA colour.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray(
        (mask_image.cpu().numpy() * 255).astype(np.uint8)
    ).convert("RGBA")
    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))


def process_boxes(boxes, src):
    """Convert detector boxes into the coordinate frame SAM expects.

    Args:
        boxes: Boxes in (cx, cy, w, h) form. They are scaled by the image
            width and height, so they are presumably normalised to [0, 1]
            (confirm against the GroundingDINO ``predict`` output).
        src: Source image array of shape (H, W, C).

    Returns:
        Boxes in (x0, y0, x1, y1) form, passed through the SAM predictor's
        ``transform.apply_boxes_torch`` and moved to ``device``.
    """
    H, W, _ = src.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
    return predictor.transform.apply_boxes_torch(
        boxes_xyxy, src.shape[:2]
    ).to(device)


def edit_image(
    path: str,
    item: str,
    prompt: str = "",
    box_threshold=BOX_TRESHOLD,
    text_threshold=TEXT_TRESHOLD,
):
    """Replace an object in an image using a text prompt.

    Args:
        path: Path of the image file to edit.
        item: Caption naming the object to detect and mask.
        prompt: Text prompt describing what to paint in the masked region.
            (The original signature read ``prompt=str``, which made the
            ``str`` type itself the default value.)
        box_threshold: Box confidence threshold passed to GroundingDINO.
        text_threshold: Text confidence threshold passed to GroundingDINO.

    Returns:
        The first image produced by the inpainting pipeline, generated from
        the source image and mask resized to ``INPAINT_SIZE``.

    Raises:
        ValueError: If no box is detected for ``item``.
    """
    src, img = load_image(path)

    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=img,
        caption=item,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
    )
    # Fail early with a clear message instead of an opaque indexing/shape
    # error further down when nothing was detected.
    if len(boxes) == 0:
        raise ValueError(
            f"No object matching {item!r} was detected; "
            "try lowering the box/text thresholds."
        )

    predictor.set_image(src)
    new_boxes = process_boxes(boxes, src)
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=new_boxes,
        multimask_output=False,
    )

    # Only the mask of the first detected box is used for inpainting.
    # The annotated preview the original built here (annotate + show_mask)
    # was never used, so that work has been dropped.
    mask = masks[0][0].cpu().numpy()

    return pipe(
        prompt=prompt,
        image=Image.fromarray(src).resize(INPAINT_SIZE),
        mask_image=Image.fromarray(mask).resize(INPAINT_SIZE),
    ).images[0]


# Define the Gradio interface.
# An earlier two-textbox interface was assigned to `iface` and immediately
# overwritten by this one, so only this definition is kept.
# NOTE(review): the Text Threshold slider defaults to 0.2 while TEXT_TRESHOLD
# is 0.25 -- confirm which value is intended.
iface = gr.Interface(
    fn=edit_image,
    inputs=[
        gr.inputs.Image(type="filepath", label="Upload Image"),
        gr.inputs.Textbox(label="Item"),
        gr.inputs.Textbox(label="Prompt"),
        gr.inputs.Slider(
            minimum=0.0, maximum=1.0, step=0.01, default=0.3,
            label="Box Threshold",
        ),
        gr.inputs.Slider(
            minimum=0.0, maximum=1.0, step=0.01, default=0.2,
            label="Text Threshold",
        ),
    ],
    outputs=gr.outputs.Image(type="numpy"),
)

iface.launch(inbrowser=True)

# Example of calling the pipeline directly, without the UI:
# path = './fire3.jpg'
# edit_image(path, "fire hydrant", "phone booth", 0.5, 0.2)