Arulkumar03 committed
Commit 3acc94f
1 Parent(s): e5e70ea

Update app.py

Files changed (1):
  app.py +106 -71
app.py CHANGED
@@ -1,98 +1,133 @@
  import argparse
- import copy
-
- from IPython.display import display
- from PIL import Image, ImageDraw, ImageFont
- from torchvision.ops import box_convert
-
- # Grounding DINO
- import groundingdino.datasets.transforms as T
  from groundingdino.models import build_model
- from groundingdino.util import box_ops
  from groundingdino.util.slconfig import SLConfig
- from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
  from groundingdino.util.inference import annotate, load_image, predict
-
- import supervision as sv
-
- # segment anything
- from segment_anything import build_sam, SamPredictor
- import cv2
- import numpy as np
- import matplotlib.pyplot as plt
-
- # diffusers
- import PIL
- import requests
- import torch
- from io import BytesIO
- from diffusers import StableDiffusionInpaintPipeline
- from huggingface_hub import hf_hub_download
-
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
-     cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
-
-     args = SLConfig.fromfile(cache_config_file)
-     args.device = device
      model = build_model(args)

      cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
-     checkpoint = torch.load(cache_file, map_location=device)
      log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
      print("Model loaded from {} \n => {}".format(cache_file, log))
      _ = model.eval()
-     return model
-
- ckpt_repo_id = "ShilongLiu/GroundingDINO"
- ckpt_filenmae = "groundingdino_swinb_cogcoor.pth"
- ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
-
- groundingdino_model = load_model_hf(ckpt_repo_id, ckpt_filenmae, ckpt_config_filename, device)
-
- checkpoint = 'sam_vit_h_4b8939.pth'
-
- predictor = SamPredictor(build_sam(checkpoint=checkpoint).to(device))
-
- # detect object using grounding DINO
- def detect(image, text_prompt, model, box_threshold = 0.3, text_threshold = 0.25):
-     boxes, logits, phrases = predict(
-         model=model,
-         image=image,
-         caption=text_prompt,
-         box_threshold=box_threshold,
-         text_threshold=text_threshold
-     )
-
-     annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
-     annotated_frame = annotated_frame[...,::-1] # BGR to RGB
-     return annotated_frame, boxes
-
- import gradio as gr
-
- # Define the Gradio interface
- def detect_objects(image, text_prompt):
-     # Convert Gradio input format to the format expected by the code
-     image_array = np.array(image)
-     image_source, _ = load_image(image_array)
-
-     # Detect objects using grounding DINO
-     annotated_frame, detected_boxes = detect(image_array, text_prompt, groundingdino_model)
-
-     # Convert the annotated frame to Gradio output format
-     annotated_image = Image.fromarray(annotated_frame)
-
-     return annotated_image
-
- # Create the Gradio interface
- iface = gr.Interface(
-     fn=detect_objects,
-     inputs=[gr.Image(), "text"],
-     outputs=gr.Image(),
-     live=True,
-     interpretation="default"
- )
-
- # Launch the Gradio interface
- iface.launch()
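The removed loader above pulled both the SwinB config and the checkpoint from the Hub. As a standalone sketch of that download pattern (repo id and filenames exactly as in the deleted lines, everything else illustrative):

    from huggingface_hub import hf_hub_download

    # Resolve the config and weights the old SwinB-based loader relied on
    config_path = hf_hub_download(repo_id="ShilongLiu/GroundingDINO", filename="GroundingDINO_SwinB.cfg.py")
    weights_path = hf_hub_download(repo_id="ShilongLiu/GroundingDINO", filename="groundingdino_swinb_cogcoor.pth")

The replacement below drops the SAM and diffusers dependencies and switches to the SwinT checkpoint with a local config file.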
 
  import argparse
+ from functools import partial
+ import cv2
+ import requests
+ import os
+ from io import BytesIO
+ from PIL import Image
+ import numpy as np
+ from pathlib import Path
+ import gradio as gr
+
+ import warnings
+
+ import torch
+
+ os.system("python setup.py build develop --user")
+ os.system("pip install packaging==21.3")
+ warnings.filterwarnings("ignore")
+
  from groundingdino.models import build_model
  from groundingdino.util.slconfig import SLConfig
+ from groundingdino.util.utils import clean_state_dict
  from groundingdino.util.inference import annotate, load_image, predict
+ import groundingdino.datasets.transforms as T

+ from huggingface_hub import hf_hub_download

+ # Use this command for evaluate the GLIP-T model
+ config_file = "groundingdino/config/GroundingDINO_SwinT_OGC.py"
+ ckpt_repo_id = "ShilongLiu/GroundingDINO"
+ ckpt_filenmae = "groundingdino_swint_ogc.pth"

+
+ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
+     args = SLConfig.fromfile(model_config_path)
      model = build_model(args)
+     args.device = device

      cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
+     checkpoint = torch.load(cache_file, map_location='cpu')
      log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
      print("Model loaded from {} \n => {}".format(cache_file, log))
      _ = model.eval()
+     return model

+ def image_transform_grounding(init_image):
+     transform = T.Compose([
+         T.RandomResize([800], max_size=1333),
+         T.ToTensor(),
+         T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+     ])
+     image, _ = transform(init_image, None) # 3, h, w
+     return init_image, image

+ def image_transform_grounding_for_vis(init_image):
+     transform = T.Compose([
+         T.RandomResize([800], max_size=1333),
+     ])
+     image, _ = transform(init_image, None) # 3, h, w
+     return image

+ model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)

+ def run_grounding(input_image, grounding_caption, box_threshold, text_threshold):
+     init_image = input_image.convert("RGB")
+     original_size = init_image.size

+     _, image_tensor = image_transform_grounding(init_image)
+     image_pil: Image = image_transform_grounding_for_vis(init_image)

+     # run grounidng
+     boxes, logits, phrases = predict(model, image_tensor, grounding_caption, box_threshold, text_threshold, device='cpu')
+     annotated_frame = annotate(image_source=np.asarray(image_pil), boxes=boxes, logits=logits, phrases=phrases)
+     image_with_box = Image.fromarray(cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB))

+     return image_with_box

+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser("Grounding DINO demo", add_help=True)
+     parser.add_argument("--debug", action="store_true", help="using debug mode")
+     parser.add_argument("--share", action="store_true", help="share the app")
+     args = parser.parse_args()
+     css = """
+     #mkd {
+         height: 500px;
+         overflow: auto;
+         border: 1px solid #ccc;
+     }
+     """
+     block = gr.Blocks(css=css).queue()
+     with block:
+         gr.Markdown("<h1><center>Grounding DINO<h1><center>")
+         gr.Markdown("<h3><center>Open-World Detection with <a href='https://github.com/IDEA-Research/GroundingDINO'>Grounding DINO</a><h3><center>")
+         gr.Markdown("<h3><center>Note the model runs on CPU, so it may take a while to run the model.<h3><center>")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(source='upload', type="pil")
+                 grounding_caption = gr.Textbox(label="Detection Prompt")
+                 run_button = gr.Button(label="Run")
+                 with gr.Accordion("Advanced options", open=False):
+                     box_threshold = gr.Slider(
+                         label="Box Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                     )
+                     text_threshold = gr.Slider(
+                         label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
+                     )
+
+             with gr.Column():
+                 gallery = gr.outputs.Image(
+                     type="pil",
+                     # label="grounding results"
+                 ).style(full_width=True, full_height=True)
+                 # gallery = gr.Gallery(label="Generated images", show_label=False).style(
+                 #     grid=[1], height="auto", container=True, full_width=True, full_height=True)
+
+         run_button.click(fn=run_grounding, inputs=[
+                         input_image, grounding_caption, box_threshold, text_threshold], outputs=[gallery])
+         gr.Examples(
+             [["this_is_fine.png", "coffee cup", 0.25, 0.25]],
+             inputs = [input_image, grounding_caption, box_threshold, text_threshold],
+             outputs = [gallery],
+             fn=run_grounding,
+             cache_examples=True,
+             label='Try this example input!'
+         )
+         block.launch(share=False, show_api=False, show_error=True)
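For a quick sanity check of the new entry point outside the Gradio UI, a sketch along these lines should work, assuming app.py is importable from the Space's root and the example image shipped for gr.Examples is present (importing app.py still runs the module-level os.system calls and downloads the SwinT checkpoint):

    from PIL import Image
    import app  # building the GroundingDINO model happens at import time

    image = Image.open("this_is_fine.png")
    # same defaults as the Box/Text Threshold sliders
    annotated = app.run_grounding(image, "coffee cup", box_threshold=0.25, text_threshold=0.25)
    annotated.save("annotated.png")

run_grounding returns a PIL image with the predicted boxes drawn, which is what the gallery output displays.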