Spaces:
Runtime error
Runtime error
# Gradio and other necessary imports
import gradio as gr
import subprocess

# setup.sh is run before the GroundingDINO / segment_anything imports below;
# presumably it installs/clones those packages and fetches the weights read
# from ./GroundingDINO/weights/ (TODO confirm against setup.sh).
_setup_result = subprocess.run(["bash", "setup.sh"])
if _setup_result.returncode != 0:
    # Deliberately not fatal (a re-run with files already present may still be
    # usable), but make the failure visible instead of letting a later
    # ImportError / missing-file error hide its cause.
    print(f"warning: setup.sh exited with status {_setup_result.returncode}")

from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionInpaintPipeline
from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate
from GroundingDINO.groundingdino.util import box_ops
from PIL import Image
import torch
import numpy as np
import os
# All models are placed on this device.
device = torch.device("cpu")

# ----SAM
print("path", os.getcwd())
model_type = "vit_h"
predictor = SamPredictor(
    sam_model_registry[model_type](
        checkpoint="./GroundingDINO/weights/sam_vit_h_4b8939.pth"
    ).to(device)
)

# ------Stable Diffusion
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32,
).to(device)

# ----Grounding DINO
# Pass the device explicitly: GroundingDINO's load_model has its own `device`
# parameter (upstream default "cuda" — confirm for the pinned version), which
# would otherwise disagree with the CPU device used above.
groundingdino_model = load_model(
    "./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "./GroundingDINO/weights/groundingdino_swint_ogc.pth",
    device=str(device),
)

# Default GroundingDINO detection thresholds (used as edit_image defaults).
BOX_TRESHOLD = 0.3
TEXT_TRESHOLD = 0.25
def show_mask(mask, image, random_color=True):
    """Alpha-composite a tinted segmentation mask on top of an image.

    Args:
        mask: Mask whose last two dimensions are (H, W); any leading
            dimensions must be singletons (e.g. (1, H, W)). May be a torch
            tensor on any device or a NumPy array; truthy pixels are tinted.
        image: Array accepted by ``PIL.Image.fromarray``; presumably an
            H x W x 3 uint8 frame matching the mask's (H, W) — sizes must
            agree for ``Image.alpha_composite``.
        random_color: If True, tint with a random RGB colour at alpha 0.8;
            otherwise use a fixed blue at alpha 0.6.

    Returns:
        H x W x 4 uint8 RGBA NumPy array.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    # Normalize to NumPy so torch tensors on any device and plain arrays both work.
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    mask = np.asarray(mask)
    h, w = mask.shape[-2:]
    # RGBA overlay: `color` where the mask is set, fully transparent elsewhere.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGBA")
    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))
def process_boxes(boxes, src):
    """Map GroundingDINO boxes onto the SAM predictor's input coordinates.

    The (cx, cy, w, h) boxes are converted to corner form, scaled by the
    width/height of ``src`` to pixel units, run through
    ``predictor.transform.apply_boxes_torch`` for the image's (H, W), and
    moved to ``device``.
    """
    height, width, _ = src.shape
    pixel_scale = torch.Tensor([width, height, width, height])
    corner_boxes = box_ops.box_cxcywh_to_xyxy(boxes)
    pixel_boxes = corner_boxes * pixel_scale
    sam_boxes = predictor.transform.apply_boxes_torch(pixel_boxes, src.shape[:2])
    return sam_boxes.to(device)
def edit_image(path: str, item: str, prompt: str = "", box_threshold=BOX_TRESHOLD, text_threshold=TEXT_TRESHOLD):
    """Repaint the object described by ``item`` according to ``prompt``.

    GroundingDINO detects ``item``, SAM converts the first detected box into a
    mask, and Stable Diffusion inpainting fills that mask using ``prompt``.

    Args:
        path: Path to the image file (Gradio passes None when nothing is uploaded).
        item: Text caption of the object to locate.
        prompt: Text describing what to paint into the masked region.
        box_threshold: GroundingDINO box threshold.
        text_threshold: GroundingDINO text threshold.

    Returns:
        The first image from the inpainting pipeline; image and mask are both
        resized to 512x512 before inpainting.

    Raises:
        gr.Error: If no image path is given or ``item`` is not detected.
    """
    if not path:
        raise gr.Error("Please upload an image.")
    src, img = load_image(path)
    # Pass the app's device explicitly; predict() otherwise falls back to its
    # own default (upstream "cuda" — confirm for the pinned version).
    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=img,
        caption=item,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=str(device),
    )
    # Without a detection there is no mask to inpaint (masks[0] would raise).
    if len(boxes) == 0:
        raise gr.Error(f"No '{item}' found in the image; try lowering the thresholds.")
    predictor.set_image(src)
    new_boxes = process_boxes(boxes, src)
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=new_boxes,
        multimask_output=False,
    )
    # Only the first detection's mask is used for inpainting.
    return pipe(
        prompt=prompt,
        image=Image.fromarray(src).resize((512, 512)),
        mask_image=Image.fromarray(masks[0][0].cpu().numpy()).resize((512, 512)),
    ).images[0]
# Define the Gradio interface.
# Top-level component classes are used because Gradio 4 removed the
# gr.inputs / gr.outputs namespaces (and Slider's `default=` keyword).
iface = gr.Interface(
    fn=edit_image,
    # One component per edit_image parameter, in order:
    # path, item, prompt, box_threshold, text_threshold.
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(label="Item"),
        gr.Textbox(label="Prompt"),
        # Slider defaults follow the module-level thresholds so the UI and
        # edit_image's own defaults agree.
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=BOX_TRESHOLD, label="Box Threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=TEXT_TRESHOLD, label="Text Threshold"),
    ],
    outputs=gr.Image(type="numpy"),
)
iface.launch(inbrowser=True)

# Example of calling the editor directly, without the UI:
# path = './fire3.jpg'
# edit_image(path, "fire hydrant", "phone booth", 0.5, 0.2)