File size: 4,089 Bytes
82ee3e2
 
80eb764
82ee3e2
 
 
80eb764
 
080fa88
 
82ee3e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80eb764
 
 
 
 
 
 
 
82ee3e2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import subprocess
import os 
import sys
# Set CUDA_HOME before the install below so the child pip process inherits it
# (presumably consumed by GroundingDINO's extension build -- TODO confirm).
os.environ["CUDA_HOME"] = "/usr/local/cuda-11.8/"

# Install the local GroundingDINO checkout in editable mode. Running pip as
# "<current interpreter> -m pip" installs into the environment that is
# executing this script rather than whichever "pip" happens to be first on PATH.
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-e", "GroundingDINO"],
    check=False,
)
if _pip_result.returncode != 0:
    # Stay best-effort (the package may already be importable), but do not
    # hide the failure: a later import error is much easier to diagnose with it.
    print(
        f"WARNING: 'pip install -e GroundingDINO' exited with code {_pip_result.returncode}",
        file=sys.stderr,
    )

# Make the local checkouts importable, resolved against the current working
# directory.
for _vendored_dir in ("GroundingDINO", "segment_anything"):
    sys.path.append(os.path.join(os.getcwd(), _vendored_dir))
# os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
# os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")

import gradio as gr
from dino_sam import sam_dino_vid

# Output paths of the most recent detection run. run_sam_dino() writes them and
# vid_download() reads them.
# NOTE(review): this is module-level state, so it is shared by every caller in
# this process; concurrent sessions would overwrite each other's paths.
CSV_PATH = ""
VID_PATH = ""


def run_sam_dino(input_vid,
                 grounding_caption,
                 box_threshold,
                 text_threshold,
                 fps_processed,
                 video_options):
    """Run detection on a video and remember where its outputs were written.

    Every argument is forwarded unchanged, in order, to ``sam_dino_vid``.

    Args:
        input_vid: The video to process.
        grounding_caption: Text describing what should be detected.
        box_threshold: Box threshold passed to the detector.
        text_threshold: Text threshold passed to the detector.
        fps_processed: Frame detection rate passed to the detector.
        video_options: Output options selected for the rendered video.

    Returns:
        The video path returned by ``sam_dino_vid``.
    """
    global CSV_PATH, VID_PATH
    # sam_dino_vid returns a (csv_path, vid_path) pair; keep both so that
    # vid_download() can offer them once the output video has been updated.
    CSV_PATH, VID_PATH = sam_dino_vid(
        input_vid,
        grounding_caption,
        box_threshold,
        text_threshold,
        fps_processed,
        video_options,
    )
    return VID_PATH


def vid_download():
    """Return the downloadable files produced by the latest detection run.

    Returns:
        A list holding the CSV path and the video path (in that order),
        omitting any path that has not been set. ``None`` when neither path is
        set, so callers are never handed empty-string "files".
    """
    print(CSV_PATH, VID_PATH)
    # The paths are still empty if no run has stored results yet (for example
    # when the output video was filled without run_sam_dino() being called).
    available = [path for path in (CSV_PATH, VID_PATH) if path]
    return available or None


# Build the Gradio interface: detection inputs in the first column, the
# rendered video and downloadable files in the second.
with gr.Blocks() as demo:
    gr.HTML(
        """
            <h1 align="center" style="font-size:xxx-large">🦍 Primate Detection</h1>
        """
    )

    # "<p="left">" is not valid markup (the whole 'p="left"' is read as the
    # tag name); use a real <p> element with an align attribute.
    gr.HTML(
        """
            <p align="left">
                The csv contains frame numbers and timestamps, bounding box coordinates, and number of detections per frame.</p>
        """
    )
    with gr.Row():
        with gr.Column():
            # Named input_video rather than "input" to avoid shadowing the builtin.
            input_video = gr.Video(label="Input Video", interactive=True)
            grounding_caption = gr.Textbox(label="What do you want to detect?")
            with gr.Accordion("Advanced Options", open=False):
                box_threshold = gr.Slider(
                    label="Box Threshold",
                    info="Adjust the threshold to change the sensitivity of the model, lower thresholds being more sensitive.",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.25,
                    step=0.01,
                )
                text_threshold = gr.Slider(
                    label="Text Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.25,
                    step=0.01,
                )
                fps_processed = gr.Slider(
                    label="Frame Detection Rate",
                    info="Adjust the frame detection rate. I.e. a value of 120 will run detection every 120 frames, a value of 1 will run detection on every frame. Note: the lower the number the slower the processing time.",
                    minimum=1,
                    maximum=120,
                    value=3,
                    step=1,
                )
                video_options = gr.CheckboxGroup(
                    choices=["Bounding boxes", "Masks"],
                    label="Video Output Options",
                    info="Select the options to display in the output video.",
                    value=["Bounding boxes"],
                    interactive=True,
                )

            # TODO: Make button visible only after a file has been uploaded
            run_btn = gr.Button(value="Run Detection", visible=True)
        with gr.Column():
            vid = gr.Video(label="Output Video", height=350, interactive=False, visible=True)
            # download_btn = gr.Button(value="Generate Download", visible=True)
            download_file = gr.Files(label="CSV, Video Output", interactive=False)

    # The same components, in run_sam_dino's parameter order, feed both the
    # button and the example row.
    detection_inputs = [
        input_video,
        grounding_caption,
        box_threshold,
        text_threshold,
        fps_processed,
        video_options,
    ]

    run_btn.click(fn=run_sam_dino, inputs=detection_inputs, outputs=[vid])
    # Refresh the download list whenever the output video changes.
    vid.change(fn=vid_download, outputs=download_file)

    gr.Examples(
        [["baboon_15s.mp4", "baboon", 0.25, 0.25, 1, ["Bounding boxes", "Masks"]]],
        inputs=detection_inputs,
        outputs=[vid],
        fn=run_sam_dino,
        cache_examples=True,
        label='Example',
    )

demo.launch(share=False)