import os
import subprocess
import sys
import warnings

warnings.filterwarnings('ignore')

# Install runtime dependencies
os.system("pip install -q gradio")
os.system("pip install -q diffusers")
os.system("pip install -q segment_anything")
os.system("pip install -q accelerate")
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Install GroundingDINO from the local checkout and make it importable
result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
print(f'pip install GroundingDINO = {result}')
sys.path.insert(0, './GroundingDINO')


'''Importing Libraries'''

import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry
from segment_anything import SamPredictor
from diffusers import StableDiffusionInpaintPipeline, AutoPipelineForInpainting

from scipy.ndimage import binary_dilation

import cv2
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import draw_segmentation_masks

torch.set_default_dtype(torch.float32)


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    '''
        Loads a model from the Hugging Face Hub; used to fetch the Grounding DINO checkpoint and config.
    '''
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file) 
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    model.eval()
    return model  


def transform_image(image) -> torch.Tensor:
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet mean/std
    ])
    image_transformed, _ = transform(image, None)
    return image_transformed


class CFG:
    '''
        Configuration: model checkpoint URLs/IDs and the compute device used throughout the script.
    '''
    # sam_type = "vit_h"
    SAM_MODELS = {
        "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
    }
    INPAINTING_MODELS = {
        "Stable Diffusion" : "runwayml/stable-diffusion-inpainting",
        "Stable Diffusion 2" : "stabilityai/stable-diffusion-2-inpainting",
        "Stable Diffusion XL" : "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    }
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    ckpt_repo_id = "ShilongLiu/GroundingDINO"
    ckpt_filename = "groundingdino_swinb_cogcoor.pth"
    ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"



'''Build models'''
def build_sam(sam_type):
    checkpoint_url = CFG.SAM_MODELS[sam_type]
    sam = sam_model_registry[sam_type]()
    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
    sam.load_state_dict(state_dict, strict=True)
    sam.to(device = CFG.device)
    sam = SamPredictor(sam)
    print('SAM is built !')
    return sam


def build_groundingdino():
    ckpt_repo_id = CFG.ckpt_repo_id
    ckpt_filename = CFG.ckpt_filename
    ckpt_config_filename = CFG.ckpt_config_filename
    groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)
    print('Grounding DINO is built !')
    return groundingdino


'''Predictions'''
def predict_dino(image_pil, text_prompt, box_threshold, text_threshold, model_groundingdino):
    image_trans = transform_image(image_pil)
    boxes, logits, phrases = predict(model = model_groundingdino,
                                     image = image_trans,
                                     caption = text_prompt,
                                     box_threshold = box_threshold,
                                     text_threshold = text_threshold,
                                     device = CFG.device)
    W, H = image_pil.size
    boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])  # normalized cxcywh -> absolute xyxy corner coords
    print('DINO prediction done !')
    return boxes, logits, phrases
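
# Hedged usage sketch of the detection step alone, assuming a Grounding DINO model built via
# build_groundingdino() and a hypothetical local image path:
#   image_pil = Image.open('stool.jpeg').convert('RGB')
#   boxes, logits, phrases = predict_dino(image_pil, 'wooden stool', 0.23, 0.25, build_groundingdino())
# boxes is an (N, 4) tensor of absolute xyxy pixel coordinates; phrases are the matched text spans.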


def predict_sam(image_pil, boxes, model_sam):
    image_array = np.asarray(image_pil)
    model_sam.set_image(image_array)
    transformed_boxes = model_sam.transform.apply_boxes_torch(boxes, image_array.shape[:2])
    masks, _, _ = model_sam.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes.to(model_sam.device),
        multimask_output=False,
    )
    print('SAM prediction done !')
    return masks.cpu()


def mask_predict(image_pil, text_prompt, box_threshold, text_threshold, models):
    boxes, logits, phrases = predict_dino(image_pil, text_prompt, box_threshold, text_threshold, models[0])
    masks = torch.tensor([])
    if len(boxes) > 0:
        masks = predict_sam(image_pil, boxes, models[1])
        masks = masks.squeeze(1)
    return masks, boxes, phrases, logits
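
# Minimal end-to-end sketch of the segmentation step, assuming models built with build_models()
# (defined further below) and a hypothetical local image path:
#   models = build_models('vit_h')
#   image_pil = Image.open('stool.jpeg').convert('RGB')
#   masks, boxes, phrases, logits = mask_predict(image_pil, 'wooden stool', 0.23, 0.25, models)
# masks is an (N, H, W) boolean tensor with one mask per detected box (empty if nothing matched).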


'''Utils'''

def load_image(image_path):
    return Image.open(image_path).convert("RGB")


def draw_image(image_pil, masks, boxes, alpha=0.4):
    image = np.asarray(image_pil)
    image = torch.from_numpy(image).permute(2, 0, 1)
    if len(masks) > 0:
        image = draw_segmentation_masks(image, masks=masks, colors=['red'] * len(masks), alpha=alpha)
    return image.numpy().transpose(1, 2, 0)

# torch.save(masks, 'masks.pt')


'''Visualise segmented results'''

def visualize_results(img1, img2, text_prompt, task):
    fig, axes = plt.subplots(1, 2, figsize=(20, 10))

    axes[0].imshow(img1)
    axes[0].set_title('Original Image')

    axes[1].imshow(img2)
    axes[1].set_title(f'{text_prompt} : {task}')

    for ax in axes:
        ax.axis('off')

# visualize_results(image_pil, output, text_prompt, 'segmented')

# x_units = 200
# y_units = -100
# text_prompt = 'wooden stool'
# image_path = '/kaggle/input/avataar/stool.jpeg'
# output_image_path = '/kaggle/working'

def build_models(sam_type):
    model_sam = build_sam(sam_type)
    model_groundingdino = build_groundingdino()
    models = [model_groundingdino, model_sam]
    return models
    

def main_fun(image_pil, x_units, y_units, text_prompt, box_threshold, text_threshold, inpaint_text_prompt, num_inference_steps, sam_type, inpainting_model):
#     x_units = 200
#     y_units = -100
#     text_prompt = 'wooden stool'
    
#     image_pil = load_image(image_path)
    models = build_models(sam_type)

    masks, boxes, phrases, logits = mask_predict(image_pil, text_prompt, box_threshold, text_threshold, models)
    segmented_image = draw_image(image_pil, masks, boxes, alpha=0.4)

    # Combine all segmentation masks into a single boolean mask
    combined_mask = torch.sum(masks, axis=0)
    combined_mask = np.where(combined_mask[:, :] != 0, True, False)
    
    '''Get masked object and background as two separate images'''
    image_np = np.asarray(image_pil)
    mask = np.expand_dims(combined_mask, axis=-1)
    masked_object = image_np * mask
    background = image_np * ~mask


    '''Shifts image by x_units and y_units'''
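    # cv2.warpAffine translation: positive x_units moves the object right; the minus sign on
    # y_units makes positive values move it up (image row indices grow downward).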
    M = np.float32([[1, 0, x_units], [0, 1, -y_units]])
    shifted_image = cv2.warpAffine(masked_object, M, (masked_object.shape[1] , masked_object.shape[0]), borderMode=cv2.BORDER_CONSTANT, borderValue=(0, 0, 0))
    masked_shifted_image = np.where(shifted_image[:, :, 0] != 0, True, False)

    '''Load the selected Stable Diffusion checkpoint fine-tuned for inpainting'''
    inpainting_model_path = CFG.INPAINTING_MODELS[inpainting_model]

    if inpainting_model=='Stable Diffusion XL':
        pipe = AutoPipelineForInpainting.from_pretrained(inpainting_model_path, 
                      torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
    else:
        pipe = StableDiffusionInpaintPipeline.from_pretrained(inpainting_model_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
    
    pipe.to(CFG.device)
    print('StableDiffusion model loaded !')

    # With Dilation
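    # (binary_dilation with a 15x15 structuring element grows the mask by a few pixels, so the
    #  inpainting also covers the thin halo/shadow left around the removed object.)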
    structuring_element = np.ones((15, 15, 1), dtype=bool)
    extrapolated_mask = binary_dilation(mask, structure=structuring_element)
    mask_as_uint8 = extrapolated_mask.astype(np.uint8) * 255
    pil_mask = Image.fromarray(mask_as_uint8.squeeze(), mode='L').resize((1024, 1024))

    # # Without Dilation
    # pil_background = Image.fromarray(background)
    # mask_as_uint8 = mask.astype(np.uint8) * 255
    # pil_mask = Image.fromarray(mask_as_uint8.squeeze(), mode='L')

    print('Image Inpainting in process.....')
    '''Do inpainting on masked locations of original image'''
    # prompt = 'fill as per background'
    prompt = inpaint_text_prompt
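    # The inpainting pipeline regenerates the white (masked) region guided by the prompt and the
    # surrounding pixels; unmasked areas are largely preserved.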
    inpainted_image = pipe(prompt=prompt, image=image_pil, mask_image=pil_mask, num_inference_steps=num_inference_steps).images[0]
    print('Image INPAINTED !')
    # inpainted_image

    '''Composite the shifted object onto the inpainted background image'''
    pil_shifted_image = Image.fromarray(shifted_image).resize(inpainted_image.size)
    np_shifted_image = np.array(pil_shifted_image)
    masked_shifted_image = np.where(np_shifted_image[:, :, 0] != 0, True, False)
    masked_shifted_image = np.expand_dims(masked_shifted_image, axis=-1)
    inpainted_shifted = np.array(inpainted_image) * ~masked_shifted_image

    shifted_image = cv2.resize(shifted_image, inpainted_image.size)
    output = inpainted_shifted + shifted_image
    output = Image.fromarray(output)
#     visualize_results(image_pil, output, text_prompt, 'shifted')
    segmented_image = Image.fromarray(segmented_image)
    return segmented_image.resize(image_pil.size), output.resize(image_pil.size)
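
# Hedged standalone usage sketch (without the Gradio UI below), assuming a hypothetical image path:
#   image_pil = load_image('/kaggle/input/avataar/stool.jpeg')
#   segmented, shifted = main_fun(image_pil, x_units=200, y_units=-100, text_prompt='wooden stool',
#                                 box_threshold=0.23, text_threshold=0.25,
#                                 inpaint_text_prompt='fill as per background',
#                                 num_inference_steps=20, sam_type='vit_h',
#                                 inpainting_model='Stable Diffusion 2')
#   segmented.save('segmented.png'); shifted.save('shifted.png')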

import gradio as gr

image_blocks = gr.Blocks()
with image_blocks as demo:
    with gr.Row():
        with gr.Column():
            image = gr.Image(sources=['upload'], type="pil", label="Upload")
            # with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
            text_prompt = gr.Textbox(placeholder = 'Object to segment and shift (e.g. wooden stool)', label="Object class", show_label=True)
            x_units = gr.Slider(minimum=0, maximum=300, step=10, value=100, label="x_units")
            y_units = gr.Slider(minimum=0, maximum=300, step=10, value=0, label="y_units")
            sam_type = gr.Dropdown(
                    ["vit_h", "vit_l", "vit_b"], label="SAM ViT backbone", value="vit_h"
                )
            inpainting_model = gr.Dropdown(
                    ["Stable Diffusion", "Stable Diffusion 2", "Stable Diffusion XL"], label="Model for inpainting", value="Stable Diffusion 2"
                )
            with gr.Accordion("Advanced options", open=False) as advanced_options:
                box_threshold = gr.Slider(
                    label="Box Threshold", minimum=0.0, maximum=1.0, value=0.23, step=0.01
                )
                num_inference_steps = gr.Slider(
                    label="Number of inference steps", minimum=20, maximum=100, value=20, step=10
                )
                text_threshold = gr.Slider(
                    label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.01
                )
                inpaint_text_prompt = gr.Textbox(placeholder = 'Your prompt (default=fill as per background)', value="fill as per background",  label="Prompt to replace object with", show_label=True)         

            # text_prompt = gr.Textbox(lines=1, label="Prompt")
            btn = gr.Button(value="Submit")
        with gr.Column():
            image_out_seg = gr.Image(label="Segmented object", height=400, width=400)
            image_out_shift = gr.Image(label="Shifted object", height=400, width=400)
            
    btn.click(fn=main_fun, inputs=[image, x_units, y_units, text_prompt, box_threshold, text_threshold, inpaint_text_prompt, num_inference_steps, sam_type, inpainting_model], outputs=[image_out_seg, image_out_shift])

image_blocks.launch(share=True)