File size: 3,901 Bytes
0ae17b9
 
 
 
 
 
 
 
 
 
 
 
3db7352
0ae17b9
 
 
 
 
 
 
 
 
 
120b633
0ae17b9
 
 
 
5431f3e
ede73d2
0ae17b9
 
 
 
 
 
 
 
 
 
 
 
4a7e314
 
0ae17b9
 
b942818
0ae17b9
 
 
 
 
 
 
 
 
 
 
 
 
 
b942818
0ae17b9
 
 
 
120b633
80811d8
 
0ae17b9
 
 
 
 
 
 
4a7e314
 
b942818
0ae17b9
 
 
d957709
0ae17b9
b942818
0ae17b9
 
 
 
6e6de0a
d1081fe
0ae17b9
 
 
 
 
b942818
0ae17b9
 
 
 
 
 
 
9b265f1
0ae17b9
b942818
 
0ae17b9
 
 
 
 
 
b942818
0ae17b9
 
 
4a7e314
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import List
import os
import numpy as np
import supervision as sv
import uuid
import torch
from tqdm import tqdm
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForObjectDetection
import spaces

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoImageProcessor.from_pretrained("PekingU/rtdetr_r50vd_coco_o365")
model = AutoModelForObjectDetection.from_pretrained("PekingU/rtdetr_r50vd_coco_o365").to(device)


BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator()
MASK_ANNOTATOR = sv.MaskAnnotator()
LABEL_ANNOTATOR = sv.LabelAnnotator()
TRACKER = sv.ByteTrack()


def calculate_end_frame_index(source_video_path):
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    return min(video_info.total_frames, video_info.fps * 2)
        


def annotate_image(
    input_image,
    detections,
    labels
) -> np.ndarray:
    output_image = MASK_ANNOTATOR.annotate(input_image, detections)
    output_image = BOUNDING_BOX_ANNOTATOR.annotate(output_image, detections)
    output_image = LABEL_ANNOTATOR.annotate(output_image, detections, labels=labels)
    return output_image


@spaces.GPU
def process_video(
    input_video,
    confidence_threshold,
    progress=gr.Progress(track_tqdm=True)
):
    video_info = sv.VideoInfo.from_video_path(input_video)
    total = calculate_end_frame_index(input_video)
    frame_generator = sv.get_video_frames_generator(
        source_path=input_video,
        end=total
    )

    result_file_name = f"{uuid.uuid4()}.mp4"
    result_file_path = os.path.join("./", result_file_name)
    with sv.VideoSink(result_file_path, video_info=video_info) as sink:
        for _ in tqdm(range(total), desc="Processing video.."):
            frame = next(frame_generator)
            results = query(Image.fromarray(frame), confidence_threshold)
            final_labels = []
            detections = []
            
            detections = sv.Detections.from_transformers(results[0])
            detections = TRACKER.update_with_detections(detections)
            for label in detections.class_id.tolist():
                final_labels.append(model.config.id2label[label])
            frame = annotate_image(
                input_image=frame,
                detections=detections,
                labels=final_labels,
            )
            sink.write_frame(frame)
    return result_file_path
    

def query(image, confidence_threshold):
  inputs = processor(images=image, return_tensors="pt").to(device)
  with torch.no_grad():
    outputs = model(**inputs)
  target_sizes = torch.tensor([image.size[::-1]])
  
  results = processor.post_process_object_detection(outputs=outputs, threshold=confidence_threshold, target_sizes=target_sizes)
  return results

with gr.Blocks() as demo:
  gr.Markdown("## Real Time Object Tracking with RT-DETR")
  gr.Markdown("This is a demo for object tracking using RT-DETR. It runs on ZeroGPU which captures GPU every first time you infer, so the model is actually faster than the inference in this demo.")
  gr.Markdown("Simply upload a video, you can also play with confidence threshold, or try the example below. 👇")
  with gr.Row():
    with gr.Column():
        input_video = gr.Video(
            label='Input Video'
        )
        conf = gr.Slider(label="Confidence Threshold", minimum=0.1, maximum=1.0, value=0.6, step=0.05)
        submit = gr.Button()
    with gr.Column():
        output_video = gr.Video(
            label='Output Video'
        )
  gr.Examples(
      fn=process_video,
      examples=[["./cat.mp4", 0.6], ["./football.mp4", 0.6]],
      inputs=[
          input_video,
          conf
      ],
      outputs=output_video
  )

  submit.click(
      fn=process_video,
      inputs=[input_video, conf],
      outputs=output_video
  )

demo.launch(show_error=True)