# Gradio app: text-guided object replacement with Grounding DINO + SAM + Stable Diffusion inpainting
import gradio as gr
import subprocess

subprocess.run(["bash", "setup.sh"])
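# setup.sh is assumed to clone GroundingDINO and download the model weights
# referenced below, so it must run before the imports that depend on them.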
from segment_anything import SamPredictor, sam_model_registry
from diffusers import StableDiffusionInpaintPipeline
from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate
from GroundingDINO.groundingdino.util import box_ops
from PIL import Image
import torch
import numpy as np
import os

device = torch.device("cpu")
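# CPU keeps the demo runnable anywhere (e.g. a CPU-only Space). If a GPU is
# available, one option is:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")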
# ----SAM

print("path", os.getcwd())

model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint="./GroundingDINO/weights/sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam.to(device))
# ------Stable Diffusion
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting", torch_dtype=torch.float32
).to(device)
# ----Grounding DINO
groundingdino_model = load_model(
    "./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "./GroundingDINO/weights/groundingdino_swint_ogc.pth",
)

BOX_THRESHOLD = 0.3
TEXT_THRESHOLD = 0.25
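# box_threshold filters Grounding DINO detections by box confidence;
# text_threshold controls which phrase tokens count as a match. Lower values
# yield more (and noisier) detections.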

def show_mask(mask, image, random_color=True):
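    """Alpha-composite a colored copy of `mask` onto `image`; returns an RGBA numpy array."""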
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")

    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

def process_boxes(boxes, src):
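    """Scale Grounding DINO's normalized cxcywh boxes to absolute xyxy pixels and map them into SAM's coordinate frame."""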
    H, W, _ = src.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
    return predictor.transform.apply_boxes_torch(boxes_xyxy, src.shape[:2]).to(device)

def edit_image(path: str, item: str, prompt: str, box_threshold=BOX_THRESHOLD, text_threshold=TEXT_THRESHOLD):
    src, img = load_image(path)
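    # Detect the item to replace with Grounding DINO.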
    boxes, logits, phrases = predict(
        model=groundingdino_model,
        image=img,
        caption=item,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )
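    # Prompt SAM with the detected boxes to get segmentation masks.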
    predictor.set_image(src)
    new_boxes = process_boxes(boxes, src)
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=new_boxes,
        multimask_output=False,
    )
    # Annotated preview (detection boxes + mask overlay); computed for
    # debugging, but not part of the returned result.
    img_annotated_mask = show_mask(
        masks[0][0].cpu(),
        # annotate() returns BGR; flip channels to RGB for PIL.
        annotate(image_source=src, boxes=boxes, logits=logits, phrases=phrases)[..., ::-1]
    )
    # Inpaint the masked region; the SD2 inpainting pipeline works at 512x512.
    return pipe(
        prompt=prompt,
        image=Image.fromarray(src).resize((512, 512)),
        mask_image=Image.fromarray(masks[0][0].cpu().numpy()).resize((512, 512))
    ).images[0]
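# Note: edit_image returns a 512x512 image (the pipeline's working size); if
# the source resolution matters, the caller can resize the result with PIL's
# Image.resize.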

# Define the Gradio interface: one input component per edit_image argument.
iface = gr.Interface(
    fn=edit_image,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Textbox(label="Item"),
        gr.Textbox(label="Prompt"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.3, label="Box Threshold"),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.25, label="Text Threshold"),
    ],
    outputs=gr.Image(type="pil"),
)

iface.launch(inbrowser=True)


# Example (run without the UI):
# path = './fire3.jpg'
# edit_image(path, "fire hydrant", "phone booth", 0.5, 0.2)