# Hugging Face Space -- running on A10G
import subprocess | |
import os | |
import sys | |
# Set CUDA_HOME to the CUDA 11.8 toolkit location before the editable install
# below runs (presumably required by the GroundingDINO build -- TODO confirm).
os.environ["CUDA_HOME"] = "/usr/local/cuda-11.8/"

# Install the local GroundingDINO checkout in editable mode.  pip is invoked
# as "<this interpreter> -m pip" so the package is installed into the
# environment that is running this script, not into whichever `pip`
# executable happens to come first on PATH.
_install = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-e", "GroundingDINO"]
)
if _install.returncode != 0:
    # Keep going (the checkout is also added to sys.path below), but do not
    # hide the failure: a later import error is much easier to diagnose when
    # the failed install is visible in the logs.
    print(
        "WARNING: 'pip install -e GroundingDINO' exited with code "
        f"{_install.returncode}",
        file=sys.stderr,
    )

# Make the local checkouts importable from the current working directory.
sys.path.append(os.path.join(os.getcwd(), "GroundingDINO"))
sys.path.append(os.path.join(os.getcwd(), "segment_anything"))

# Checkpoint downloads are disabled; the commands are kept for reference.
# os.system("wget https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth")
# os.system("wget https://huggingface.co/spaces/mrtlive/segment-anything-model/resolve/main/sam_vit_h_4b8939.pth")
import gradio as gr | |
from dino_sam import sam_dino_vid | |
# Output paths of the most recent detection run.  run_sam_dino() overwrites
# both values and vid_download() reads them; each starts as an empty string.
CSV_PATH: str = ""
VID_PATH: str = ""
def run_sam_dino(
    input_vid,
    grounding_caption,
    box_threshold,
    text_threshold,
    fps_processed,
    video_options,
):
    """Run detection on a video and record where the outputs were written.

    All six arguments are forwarded unchanged, in order, to ``sam_dino_vid``,
    which returns a ``(csv_path, vid_path)`` pair.  Both paths are stored in
    the module-level ``CSV_PATH`` and ``VID_PATH`` globals so that
    ``vid_download`` can return them later.

    Returns:
        The ``vid_path`` element of the pair returned by ``sam_dino_vid``.
    """
    # NOTE(review): these globals are shared by the whole process, so
    # overlapping runs would overwrite each other's paths -- confirm whether
    # the app can serve more than one request at a time.
    global CSV_PATH, VID_PATH

    csv_out, vid_out = sam_dino_vid(
        input_vid,
        grounding_caption,
        box_threshold,
        text_threshold,
        fps_processed,
        video_options,
    )
    CSV_PATH, VID_PATH = csv_out, vid_out
    return vid_out
def vid_download():
    """Return the output files of the latest detection run.

    Reads the module-level ``CSV_PATH`` and ``VID_PATH`` values and prints
    them to stdout before returning.

    Returns:
        A list of the paths that have been set, CSV first and video second,
        or ``None`` when neither path has been set yet.
    """
    print(CSV_PATH, VID_PATH)
    # Both globals are "" until a detection run has stored real paths.  An
    # empty string is not a usable file path, so drop it instead of handing
    # it to the caller as if it were a file.
    available = [path for path in (CSV_PATH, VID_PATH) if path]
    return available or None
# Build the Gradio interface: inputs and options on the left, the annotated
# output video and downloadable files on the right.
# NOTE(review): the nesting of the `with` blocks below was reconstructed from
# a source whose indentation was lost.  In particular, which controls live
# inside the "Advanced Options" accordion should be confirmed.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1 align="center" style="font-size:xxx-large">🦍 Primate Detection</h1>
        """
    )
    gr.HTML(
        """
        <p align="left">
        The csv contains frame numbers and timestamps, bounding box coordinates, and number of detections per frame.</p>
        """
    )

    with gr.Row():
        with gr.Column():
            # Named `input_vid` rather than `input` so the builtin input()
            # is not shadowed at module level.
            input_vid = gr.Video(label="Input Video", interactive=True)
            grounding_caption = gr.Textbox(label="What do you want to detect?")

            with gr.Accordion("Advanced Options", open=False):
                box_threshold = gr.Slider(
                    label="Box Threshold",
                    info="Adjust the threshold to change the sensitivity of the model, lower thresholds being more sensitive.",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.25,
                    step=0.01,
                )
                text_threshold = gr.Slider(
                    label="Text Threshold",
                    minimum=0.0,
                    maximum=1.0,
                    value=0.25,
                    step=0.01,
                )
                fps_processed = gr.Slider(
                    label="Frame Detection Rate",
                    info="Adjust the frame detection rate. I.e. a value of 120 will run detection every 120 frames, a value of 1 will run detection on every frame. Note: the lower the number the slower the processing time.",
                    minimum=1,
                    maximum=120,
                    value=3,
                    step=1,
                )
                video_options = gr.CheckboxGroup(
                    choices=["Bounding boxes", "Masks"],
                    label="Video Output Options",
                    info="Select the options to display in the output video.",
                    value=["Bounding boxes"],
                    interactive=True,
                )

            # TODO: Make button visible only after a file has been uploaded
            run_btn = gr.Button(value="Run Detection", visible=True)

        with gr.Column():
            vid = gr.Video(
                label="Output Video",
                height=350,
                interactive=False,
                visible=True,
            )
            # download_btn = gr.Button(value="Generate Download", visible=True)
            download_file = gr.Files(label="CSV, Video Output", interactive=False)

    # The order of these components matches the positional parameters of
    # run_sam_dino; the same list is reused for the example below.
    detection_inputs = [
        input_vid,
        grounding_caption,
        box_threshold,
        text_threshold,
        fps_processed,
        video_options,
    ]

    # Running detection fills the output video; a change of the output video
    # then triggers vid_download to populate the file list.
    run_btn.click(fn=run_sam_dino, inputs=detection_inputs, outputs=[vid])
    vid.change(fn=vid_download, outputs=download_file)

    gr.Examples(
        [["baboon_15s.mp4", "baboon", 0.25, 0.25, 1, ["Bounding boxes", "Masks"]]],
        inputs=detection_inputs,
        outputs=[vid],
        fn=run_sam_dino,
        cache_examples=True,
        label='Example',
    )

demo.launch(share=False)