annading committed on
Commit
adab5b0
1 Parent(s): 2081ef8

updated inference, testing batch size

Browse files
Files changed (2) hide show
  1. app.py +1 -9
  2. dino_sam.py +10 -7
app.py CHANGED
@@ -7,8 +7,6 @@ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:5000'
7
  subprocess.run(['pip', 'install', '-e', 'GroundingDINO'])
8
  sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
9
  sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
10
- # os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
11
- # os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
12
 
13
  import gradio as gr
14
  from dino_sam import sam_dino_vid
@@ -43,12 +41,6 @@ with gr.Blocks() as demo:
43
  """
44
  )
45
 
46
- gr.HTML(
47
- """
48
- <p="left">
49
- The csv contains frame numbers and timestamps, bounding box coordinates, and number of detections per frame.</p>
50
- """
51
- )
52
  with gr.Row():
53
  with gr.Column():
54
  input = gr.Video(label="Input Video", interactive=True)
@@ -74,7 +66,7 @@ with gr.Blocks() as demo:
74
  step=1)
75
  video_options = gr.CheckboxGroup(choices=["Bounding boxes", "Masks"],
76
  label="Video Output Options",
77
- info="Select the options to display in the output video.",
78
  value=["Bounding boxes"],
79
  interactive=True)
80
 
 
7
  subprocess.run(['pip', 'install', '-e', 'GroundingDINO'])
8
  sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
9
  sys.path.append(os.path.join(os.getcwd(), "segment_anything"))
 
 
10
 
11
  import gradio as gr
12
  from dino_sam import sam_dino_vid
 
41
  """
42
  )
43
 
 
 
 
 
 
 
44
  with gr.Row():
45
  with gr.Column():
46
  input = gr.Video(label="Input Video", interactive=True)
 
66
  step=1)
67
  video_options = gr.CheckboxGroup(choices=["Bounding boxes", "Masks"],
68
  label="Video Output Options",
69
+ info="Select the options to display in the output video. Note: if masks are selected, runtime will increase.",
70
  value=["Bounding boxes"],
71
  interactive=True)
72
 
dino_sam.py CHANGED
@@ -8,7 +8,7 @@ import torch
8
  import csv
9
  # import pstats
10
  import warnings
11
- # from memory_profiler import profile
12
  # from pstats import SortKey
13
  from tqdm import tqdm
14
  from torchvision.ops import box_convert
@@ -26,6 +26,7 @@ def prepare_image(image, transform, device):
26
  image = torch.as_tensor(image, device=device.device)
27
  return image.permute(2, 0, 1).contiguous()
28
 
 
29
  def sam_dino_vid(
30
  vid_path: str,
31
  text_prompt: str,
@@ -36,7 +37,7 @@ def sam_dino_vid(
36
  config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
37
  weights_path: str = "weights/groundingdino_swint_ogc.pth",
38
  device: str = 'cuda',
39
- batch_size: int = 5
40
  ) -> (str, str):
41
  """ Args:
42
  Returns:
@@ -101,13 +102,13 @@ def sam_dino_vid(
101
 
102
  annotated_frame_paths = [os.path.join(frames_dir, os.path.basename(frame_path)) for frame_path in batch_paths]
103
  # convert images_orig to rgb from bgr
104
- images_orig = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images_orig]
105
 
106
  if masks_needed:
107
  # run SAM in batches on boxes from dino
108
  batched_input = []
109
  sam_boxes = []
110
- for image, box in zip(images_orig, boxes_i):
111
  height, width = image.shape[:2]
112
  # convert the boxes from groundingDINO format to SAM format
113
  box = box * torch.Tensor([width, height, width, height])
@@ -123,7 +124,7 @@ def sam_dino_vid(
123
  # write to annotated_frames_dir for stitching
124
  mask = prediction["masks"].cpu().numpy()
125
  box = sam_boxes[i].cpu().numpy()
126
- annotated_frame = plot_sam(images_orig[i], mask, box, boxes_shown=boxes_needed)
127
  cv2.imwrite(annotated_frame_paths[i], annotated_frame)
128
 
129
  elif boxes_needed and not masks_needed:
@@ -215,6 +216,8 @@ def plot_sam(
215
  return image
216
 
217
  # if __name__ == '__main__':
 
 
218
  # start_time = datetime.datetime.now()
219
- # sam_dino_vid("baboon_15s.mp4", "baboon", box_threshold=0.3, text_threshold=0.3, fps_processed=30, video_options=['Bounding boxes', 'Masks'])
220
- # print("elapsed: " + str(datetime.datetime.now() - start_time))
 
8
  import csv
9
  # import pstats
10
  import warnings
11
+ from memory_profiler import profile
12
  # from pstats import SortKey
13
  from tqdm import tqdm
14
  from torchvision.ops import box_convert
 
26
  image = torch.as_tensor(image, device=device.device)
27
  return image.permute(2, 0, 1).contiguous()
28
 
29
+ # @profile
30
  def sam_dino_vid(
31
  vid_path: str,
32
  text_prompt: str,
 
37
  config_path: str = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
38
  weights_path: str = "weights/groundingdino_swint_ogc.pth",
39
  device: str = 'cuda',
40
+ batch_size: int = 10
41
  ) -> (str, str):
42
  """ Args:
43
  Returns:
 
102
 
103
  annotated_frame_paths = [os.path.join(frames_dir, os.path.basename(frame_path)) for frame_path in batch_paths]
104
  # convert images_orig to rgb from bgr
105
+ images_orig_rgb = [cv2.cvtColor(image, cv2.COLOR_BGR2RGB) for image in images_orig]
106
 
107
  if masks_needed:
108
  # run SAM in batches on boxes from dino
109
  batched_input = []
110
  sam_boxes = []
111
+ for image, box in zip(images_orig_rgb, boxes_i):
112
  height, width = image.shape[:2]
113
  # convert the boxes from groundingDINO format to SAM format
114
  box = box * torch.Tensor([width, height, width, height])
 
124
  # write to annotated_frames_dir for stitching
125
  mask = prediction["masks"].cpu().numpy()
126
  box = sam_boxes[i].cpu().numpy()
127
+ annotated_frame = plot_sam(images_orig_rgb[i], mask, box, boxes_shown=boxes_needed)
128
  cv2.imwrite(annotated_frame_paths[i], annotated_frame)
129
 
130
  elif boxes_needed and not masks_needed:
 
216
  return image
217
 
218
  # if __name__ == '__main__':
219
+ # def run_sam_dino_vid():
220
+ # sam_dino_vid("baboon_15s.mp4", "baboon", box_threshold=0.3, text_threshold=0.3, fps_processed=30, video_options=['Bounding boxes', 'Masks'])
221
  # start_time = datetime.datetime.now()
222
+ # stats = run_sam_dino_vid()
223
+ # print("elapsed: " + str(datetime.datetime.now() - start_time))