File size: 6,087 Bytes
f2b92aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import functools
import os

import av
import cv2
import numpy as np
import torch
import gradio as gr
from transformers import AutoProcessor, TvpForVideoGrounding



def pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Convert the video from its original fps to the target_fps and decode the video with PyAV decoder.

    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): index of the clip to decode when the video is uniformly split
            into num_clips clips. Random temporal sampling is NOT implemented here:
            a negative clip_idx simply yields a start index <= 0, so decoding starts
            at the beginning of the stream.
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have different fps, convert it to
            the target video fps before frame sampling.
    Returns:
        frames (list): decoded PyAV frames sorted by presentation timestamp (pts).
            Frames without a pts are skipped. The list may be empty.
        fps (float): the average frame rate reported by the video stream.
    Raises:
        IndexError: if the container has no video stream.
    '''
    video = container.streams.video[0]
    fps = float(video.average_rate)
    # Number of source-rate frames spanned by num_frames samples taken at target_fps.
    clip_size = sampling_rate * num_frames / target_fps * fps

    if video.duration is None:
        # Some containers do not report a stream duration, so the pts window cannot
        # be computed; fall back to decoding the whole stream.
        video_start_pts, video_end_pts = 0, float("inf")
    else:
        # NOTE(review): upstream SlowFast uses the stream's total frame count (not
        # num_frames) for `delta` and `timebase`. Kept as-is so the sampled model
        # input is unchanged; only matters for clip_idx > 0 and early stopping.
        delta = max(num_frames - clip_size, 0)
        start_idx = delta * clip_idx / num_clips
        end_idx = start_idx + clip_size - 1
        timebase = video.duration / num_frames
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)

    # Seek a margin before the start pts; with backward=True / any_frame=False the
    # demuxer lands on a keyframe at or before this offset so decoding is valid.
    seek_offset = max(video_start_pts - 1024, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=video)

    frames_by_pts = {}
    for frame in container.decode(video=0):
        # Frames without a timestamp cannot be placed in the window; comparing
        # None with an int would raise TypeError.
        if frame.pts is None or frame.pts < video_start_pts:
            continue
        frames_by_pts[frame.pts] = frame
        # The first frame past the window is kept, then decoding stops.
        if frame.pts > video_end_pts:
            break
    # Decode order may differ from presentation order; sort by pts.
    return [frames_by_pts[pts] for pts in sorted(frames_by_pts)], fps


def decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Decode the video and perform temporal sampling.

    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): index of the clip to decode when the video is uniformly split
            into num_clips clips (values below -2 are rejected).
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have different fps, convert it to
            the target video fps before frame sampling.
    Returns:
        frames (np.ndarray): num_frames sampled RGB frames stacked on the first axis,
            with the last per-frame axis (presumably channels of an (H, W, C) array —
            confirm against PyAV's to_ndarray) moved to position 1.
    Raises:
        ValueError: if clip_idx < -2, or if no frames could be decoded.
    '''
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if clip_idx < -2:
        raise ValueError(f"Not a valid clip_idx {clip_idx}")
    frames, fps = pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps)
    if not frames:
        # Otherwise np.clip below would yield index -1 into an empty list.
        raise ValueError("No video frames could be decoded from the container.")
    # Same span as computed in pyav_decode: source frames covered by the clip.
    clip_size = sampling_rate * num_frames / target_fps * fps
    # Evenly spaced indices over the clip, clamped to the decoded range so a short
    # video repeats its last frame instead of indexing out of bounds.
    index = np.linspace(0, clip_size - 1, num_frames)
    index = np.clip(index, 0, len(frames) - 1).astype(np.int64)
    frames = np.array([frames[idx].to_rgb().to_ndarray() for idx in index])
    return frames.transpose(0, 3, 1, 2)


def get_video_duration(filename):
    '''
    Return the duration of a video file in seconds (frame count / frame rate).

    Args:
        filename (str): path to the video file.
    Returns:
        float: duration in seconds, or -1 if the file cannot be opened or reports
            a non-positive frame rate.
    '''
    cap = cv2.VideoCapture(filename)
    try:
        if not cap.isOpened():
            return -1
        rate = cap.get(cv2.CAP_PROP_FPS)
        frame_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        if rate <= 0:
            # Some files report 0 fps; avoid a ZeroDivisionError.
            return -1
        return frame_num / rate
    finally:
        # Always free the capture handle, including on the early returns.
        cap.release()


@functools.lru_cache(maxsize=4)
def _load_model_and_processor(model_checkpoint, device):
    '''Load the TVP model (moved to device) and its processor, cached per (checkpoint, device).'''
    model = TvpForVideoGrounding.from_pretrained(model_checkpoint).to(device)
    processor = AutoProcessor.from_pretrained(model_checkpoint)
    return model, processor


def predict_durations(model_checkpoint, text, video_filename, device="cpu"):
    '''
    Predict the start and end time (in seconds) of the event described by text.

    Args:
        model_checkpoint (str): checkpoint name or path passed to from_pretrained.
        text (str): description of the event to localize in the video.
        video_filename (str): path of the video file.
        device (str or torch.device): device on which the model and inputs are placed.
    Returns:
        str: "start: <start>s, end: <end>s", or a short message when no video was
            given or its duration could not be determined.
    '''
    if not video_filename:
        return "Please provide a video file."
    # Compute the duration first so an unreadable file fails before inference.
    duration = get_video_duration(video_filename)
    if duration <= 0:
        return "Could not determine the video duration."

    print(f"Loading model: {model_checkpoint}")
    # Cached: loading weights on every click is slow.
    model, processor = _load_model_and_processor(model_checkpoint, device)

    print(f"Loading video: {video_filename}")
    # Context manager closes the PyAV container once frames are decoded.
    with av.open(video_filename, metadata_errors="ignore") as container:
        raw_sampled_frames = decode(
            container=container,
            sampling_rate=1,
            num_frames=model.config.num_frames,
            clip_idx=0,
            num_clips=1,
            target_fps=3,
        )

    print("Processing video and text")
    model_inputs = processor(
        text=[text], videos=list(raw_sampled_frames), return_tensors="pt", max_text_length=100
    ).to(device)

    print("Running inference")
    # No gradients are needed for inference; avoids building an autograd graph.
    with torch.no_grad():
        output = model(**model_inputs)
    start, end = processor.post_process_video_grounding(output.logits, duration)
    return f"start: {start}s, end: {end}s"


# Access token read from the environment (None if unset).
# NOTE(review): not referenced anywhere in the visible code.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# First CUDA device when available, otherwise CPU.
# NOTE(review): not passed to predict_durations below, so inference runs with
# that function's default device="cpu".
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Checkpoints offered in the model dropdown; the first one is the default.
MODELS = ["Intel/tvp-base",  "Intel/tvp-base-ANet"]
# (video path, query text) pairs wired to video_in / text_in as examples.
EXAMPLES = [
    ["./examples/bed.mp4", "a person is sitting on a bed."],
    ["./examples/food.mp4", "a person eats some food."],
    ["./examples/book.mp4", "a person reads a book."],
]

# Components are created outside the Blocks context and placed into the layout
# below via .render(), so creation order here does not dictate layout order.
model_checkpoint = gr.Dropdown(MODELS, label="Model", value=MODELS[0], type="value")
video_in = gr.Video(label="Video File", elem_id="video_in")
text_in = gr.Textbox(label="Text", placeholder="Description of event in the video", interactive=True)
text_out = gr.Textbox(label="Prediction", placeholder="Predicted start and end time")


title = "Video Grounding with TVP"
DESCRIPTION = """# Video Grounding with TVP"""
# CSS hiding elements with class .toast-wrap (presumably Gradio's toasts).
# NOTE(review): defined but never passed to gr.Blocks, so it has no effect.
css = """.toast-wrap { display: none !important } """
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        model_checkpoint.render()

    with gr.Row():
        examples = gr.Examples(examples=EXAMPLES, inputs=[video_in, text_in])

    with gr.Row():
        with gr.Column():
            video_in.render()

        with gr.Column():
            text_in.render()
            time_button = gr.Button("Get start and end time")
            # On click, predict_durations receives (checkpoint, text, video value —
            # used as a file path by predict_durations); its returned string is
            # shown in text_out.
            time_button.click(predict_durations, inputs=[model_checkpoint, text_in, video_in], outputs=[text_out])
            text_out.render()


# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch(debug=True)