Arulkumar03 committed on
Commit
bffd5b1
1 Parent(s): 148168e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import copy
3
+
4
+ from IPython.display import display
5
+ from PIL import Image, ImageDraw, ImageFont
6
+ from torchvision.ops import box_convert
7
+
8
+ # Grounding DINO
9
+ import groundingdino.datasets.transforms as T
10
+ from groundingdino.models import build_model
11
+ from groundingdino.util import box_ops
12
+ from groundingdino.util.slconfig import SLConfig
13
+ from groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap
14
+ from groundingdino.util.inference import annotate, load_image, predict
15
+
16
+ import supervision as sv
17
+
18
+ # segment anything
19
+ from segment_anything import build_sam, SamPredictor
20
+ import cv2
21
+ import numpy as np
22
+ import matplotlib.pyplot as plt
23
+
24
+
25
+ # diffusers
26
+ import PIL
27
+ import requests
28
+ import torch
29
+ from io import BytesIO
30
+ from diffusers import StableDiffusionInpaintPipeline
31
+ from huggingface_hub import hf_hub_download
32
+
33
def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    """Build a Grounding DINO model from files hosted in a Hugging Face Hub repo.

    Args:
        repo_id: Hub repository that holds both the config and the checkpoint.
        filename: Name of the checkpoint file inside the repository.
        ckpt_config_filename: Name of the model config file inside the
            repository.
        device: Device stored on the loaded config and used as
            ``map_location`` when the checkpoint is deserialized.

    Returns:
        The model with the checkpoint weights applied, switched to eval mode.
    """
    # Config first: the architecture is built from it before any weights load.
    config_path = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)
    model_config = SLConfig.fromfile(config_path)
    model_config.device = device
    model = build_model(model_config)

    weights_path = hf_hub_download(repo_id=repo_id, filename=filename)
    state = torch.load(weights_path, map_location=device)
    # strict=False: key mismatches are reported in the returned result (and
    # printed below) instead of raising.
    load_result = model.load_state_dict(
        clean_state_dict(state['model']), strict=False
    )
    print("Model loaded from {} \n => {}".format(weights_path, load_result))

    model.eval()
    return model
46
+
47
# `device` was referenced below without ever being defined, which raised a
# NameError at import time. Use the GPU when one is available, else the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Grounding DINO checkpoint and config, both fetched from the Hugging Face Hub.
ckpt_repo_id = "ShilongLiu/GroundingDINO"
ckpt_filename = "groundingdino_swinb_cogcoor.pth"
ckpt_filenmae = ckpt_filename  # original (misspelled) name, kept as an alias
ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"

groundingdino_model = load_model_hf(
    ckpt_repo_id, ckpt_filename, ckpt_config_filename, device
)

# Segment Anything weights. Nothing in this file downloads them, so the
# checkpoint is expected to be present at this relative path already.
checkpoint = 'sam_vit_h_4b8939.pth'

predictor = SamPredictor(build_sam(checkpoint=checkpoint).to(device))
56
+
57
+ # detect object using grounding DINO
58
def detect(image, text_prompt, model, box_threshold=0.3, text_threshold=0.25,
           image_source=None):
    """Detect objects described by ``text_prompt`` and draw them on the image.

    Args:
        image: Preprocessed image forwarded to ``predict`` (the normalized
            tensor, not the raw pixels).
        text_prompt: Caption describing the objects to look for.
        model: Grounding DINO model used for prediction.
        box_threshold: Forwarded to ``predict`` as ``box_threshold``.
        text_threshold: Forwarded to ``predict`` as ``text_threshold``.
        image_source: Original image array that ``annotate`` draws on. The
            previous implementation read this from an undefined name and
            failed with ``NameError``; it is now an explicit argument.

    Returns:
        A ``(annotated_frame, boxes)`` tuple: the annotated image with its
        channel order reversed (BGR to RGB) and the boxes from ``predict``.

    Raises:
        ValueError: If ``image_source`` is not supplied.
    """
    if image_source is None:
        raise ValueError(
            "detect() needs image_source (the original image array) to draw "
            "the annotations on"
        )

    # predict() defaults to "cuda"; run on whatever device the model actually
    # lives on so CPU-only hosts work as well.
    model_device = str(next(model.parameters()).device)

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=text_prompt,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        device=model_device,
    )

    annotated_frame = annotate(
        image_source=image_source, boxes=boxes, logits=logits, phrases=phrases
    )
    annotated_frame = annotated_frame[..., ::-1]  # BGR to RGB
    return annotated_frame, boxes


import gradio as gr


# Define the Gradio interface
def detect_objects(image, text_prompt):
    """Gradio callback: annotate ``image`` with objects matching ``text_prompt``.

    Args:
        image: Image supplied by the Gradio image input, or ``None`` when the
            input is empty.
        text_prompt: Free-text description of the objects to detect.

    Returns:
        A PIL image with the detections drawn on it, the unannotated image
        when the prompt is blank, or ``None`` when no image was supplied.
    """
    # The interface runs live, so the callback can fire before both inputs
    # have been filled in.
    if image is None:
        return None
    image_pil = Image.fromarray(np.asarray(image)).convert("RGB")
    if not text_prompt or not text_prompt.strip():
        return image_pil

    # load_image() only accepts a file path, so apply the same preprocessing
    # to the in-memory image here: resize, convert to tensor, normalize.
    image_source = np.asarray(image_pil)
    preprocess = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_tensor, _ = preprocess(image_pil, None)

    # Detect objects using grounding DINO
    annotated_frame, _ = detect(
        image_tensor,
        text_prompt,
        groundingdino_model,
        image_source=image_source,
    )

    # The channel flip yields a negatively-strided view; make it contiguous
    # before handing it to PIL.
    return Image.fromarray(np.ascontiguousarray(annotated_frame))
87
+
88
+ # Create the Gradio interface
89
# One image plus a free-text prompt go in; one annotated image comes out.
_image_in = gr.Image()
_image_out = gr.Image()

iface = gr.Interface(
    fn=detect_objects,
    inputs=[_image_in, "text"],
    outputs=_image_out,
    live=True,
    interpretation="default",
)

# Launch the Gradio interface
iface.launch()