# tvp / app.py
# Amy Roberts · "Tidy up" · ac60993 · 7 kB
import functools
import os
import random

import av
import cv2
import gradio as gr
import numpy as np
import torch
from transformers import AutoProcessor, TvpForVideoGrounding
def pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Decode the frames of one temporal clip of a video with the PyAV decoder.

    The clip covers ``sampling_rate * num_frames`` frames at ``target_fps``, i.e.
    ``sampling_rate * num_frames / target_fps * fps`` frames of the source video
    (``fps`` being the stream's average rate). Only frames whose pts falls inside
    the clip are kept, and decoding stops right after the clip's last frame.

    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, the clip start is chosen uniformly at
            random. If clip_idx is larger than -1, uniformly split the video into
            num_clips clips and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have a different fps; the clip length
            is expressed at this fps before being mapped back to the source fps.
    Returns:
        frames (list): decoded PyAV frames of the clip, sorted by pts.
        fps (float): the average number of frames per second of the video stream.
    Raises:
        IndexError: if the container has no video stream.
    '''
    stream = container.streams.video[0]
    fps = float(stream.average_rate)
    # Number of *source* frames spanned by the requested clip.
    clip_size = sampling_rate * num_frames / target_fps * fps

    total_frames = stream.frames
    duration = stream.duration
    if not duration or not total_frames:
        # Without duration / frame-count metadata the clip cannot be located in
        # pts units, so fall back to decoding the whole stream from the start.
        video_start_pts, video_end_pts = 0, float("inf")
    else:
        # Slack available for placing the clip start inside the video.
        delta = max(total_frames - clip_size, 0)
        if clip_idx == -1:
            start_idx = random.uniform(0, delta)
        else:
            start_idx = delta * clip_idx / num_clips
        end_idx = start_idx + clip_size - 1
        # pts units per source frame (duration and pts share the stream time base).
        timebase = duration / total_frames
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)

    # Seek (to a keyframe, searching backwards) a fixed 1024-pts margin before
    # the clip start; frames decoded before the clip are skipped below.
    seek_offset = max(video_start_pts - 1024, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=stream)

    frames = {}
    for frame in container.decode(video=0):
        if frame.pts < video_start_pts:
            continue
        frames[frame.pts] = frame
        if frame.pts > video_end_pts:
            break
    return [frames[pts] for pts in sorted(frames)], fps
def decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Decode one clip of the video and sample ``num_frames`` evenly spaced frames from it.

    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling.
            If clip_idx is larger than -1, uniformly split the video to num_clips
            clips, and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have different fps, convert it to
            the target video fps before frame sampling.
    Returns:
        frames (np.ndarray): the sampled RGB frames stacked channels-first, i.e.
            (num_frames, channels, height, width) — assuming PyAV's ``to_ndarray``
            yields (height, width, channels) for RGB frames.
    Raises:
        ValueError: if clip_idx < -1 or no frame could be decoded.
    '''
    if clip_idx < -1:
        raise ValueError(f"Not a valid clip_idx {clip_idx}")
    frames, fps = pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps)
    if not frames:
        raise ValueError("No frames could be decoded from the video")
    clip_size = sampling_rate * num_frames / target_fps * fps
    # Evenly spaced indices over the clip; indices beyond the decoded frames
    # (e.g. a video shorter than the clip) are clamped to the last frame.
    index = np.linspace(0, clip_size - 1, num_frames)
    index = np.clip(index, 0, len(frames) - 1).astype(np.int64)
    frames = np.array([frames[idx].to_rgb().to_ndarray() for idx in index])
    # Move the channel axis in front of the spatial axes.
    return frames.transpose(0, 3, 1, 2)
def get_video_duration(filename):
    '''
    Return the duration of a video in seconds as frame count / frame rate reported by OpenCV.

    Args:
        filename: a video path, or a dict of the form ``{'video': {'path': ...}}``.
    Returns:
        float: the duration in seconds, or -1 if the video cannot be opened or
        reports a non-positive frame rate.
    '''
    cap = cv2.VideoCapture(_extract_video_filepath(filename))
    try:
        if not cap.isOpened():
            return -1
        rate = cap.get(cv2.CAP_PROP_FPS)
        frame_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        # Some containers report 0 fps; avoid a ZeroDivisionError.
        if rate <= 0:
            return -1
        return frame_num / rate
    finally:
        # Always free the underlying capture handle.
        cap.release()
def _extract_video_filepath(video_filename):
if isinstance(video_filename, dict):
return video_filename['video']['path']
return video_filename
@functools.lru_cache(maxsize=2)
def _load_model_and_processor(model_checkpoint):
    '''Load the TVP model and processor for ``model_checkpoint``, cached across calls.'''
    print(f"Loading model: {model_checkpoint}")
    model = TvpForVideoGrounding.from_pretrained(model_checkpoint)
    processor = AutoProcessor.from_pretrained(model_checkpoint)
    return model, processor


def predict_durations(model_checkpoint, text, video_filename, device="cpu"):
    '''
    Predict when the event described by ``text`` happens in a video.

    Args:
        model_checkpoint (str): checkpoint name passed to ``from_pretrained``.
        text (str): natural-language description of the event.
        video_filename: a video path, or a dict of the form ``{'video': {'path': ...}}``.
        device (str | torch.device): device the model and inputs are moved to.
    Returns:
        str: ``"start: {start}s, end: {end}s"`` as produced by the processor's
        ``post_process_video_grounding``.
    Raises:
        gr.Error: if no video is given or its duration cannot be determined.
    '''
    if video_filename is None:
        raise gr.Error("Please provide a video.")
    model, processor = _load_model_and_processor(model_checkpoint)
    # The inputs are moved to `device` below, so the model must live there too.
    model.to(device)

    print(f"Loading video: {video_filename}")
    filepath = _extract_video_filepath(video_filename)
    # Sample model.config.num_frames frames at 3 fps from the first clip.
    with av.open(filepath, metadata_errors="ignore") as container:
        raw_sampled_frames = decode(
            container=container,
            sampling_rate=1,
            num_frames=model.config.num_frames,
            clip_idx=0,
            num_clips=1,
            target_fps=3,
        )

    print("Processing video and text")
    model_inputs = processor(
        text=[text], videos=list(raw_sampled_frames), return_tensors="pt", max_text_length=100
    ).to(device)

    print("Running inference")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**model_inputs)

    duration = get_video_duration(video_filename)
    if duration <= 0:
        raise gr.Error("Could not determine the video duration.")
    start, end = processor.post_process_video_grounding(output.logits, duration)
    return f"start: {start}s, end: {end}s"
# Token read from the HF_TOKEN environment variable (None when unset);
# not referenced in the code shown here.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# First CUDA device when available, CPU otherwise.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Checkpoints offered in the model dropdown; the first one is the default.
MODELS = [
    "Intel/tvp-base",
    "Intel/tvp-base-ANet",
]

# (description, video path) pairs shown as clickable examples.
_EXAMPLE_PROMPTS = (
    ("a person is sitting on a bed.", "./examples/bed.mp4"),
    ("a person eats some food.", "./examples/food.mp4"),
    ("a person reads a book.", "./examples/book.mp4"),
)
# Each example row is [checkpoint, description, video path]; all use the base checkpoint.
EXAMPLES = [[MODELS[0], prompt, video_path] for prompt, video_path in _EXAMPLE_PROMPTS]

title = "Video Grounding with TVP"
DESCRIPTION = """# Video Grounding with TVP"""
# Gradio UI: choose a checkpoint, provide a video and an event description,
# and display the start/end time returned by predict_durations.
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown(
        """
Video Grounding is the task of localizing a moment in a video that best matches a natural language description.
For example, given the video of a person sitting on a bed, the model should be able to predict the start and end time of the video that best matches the description "a person is sitting on a bed".
Enter a description of an event in the video and select a video to see the predicted start and end time.
"""
    )
    with gr.Row():
        model_checkpoint = gr.Dropdown(MODELS, label="Model", value=MODELS[0], type="value")
    with gr.Row(equal_height=True):
        # Gradio column scales should be integers; 1:2 keeps the original
        # 0.5:1 width ratio between the video column and the text column.
        with gr.Column(scale=1):
            video_in = gr.Video(label="Video File", elem_id="video_in")
        with gr.Column(scale=2):
            text_in = gr.Textbox(label="Text", placeholder="Description of event in the video", interactive=True)
            text_out = gr.Textbox(label="Prediction", placeholder="Predicted start and end time")
            time_button = gr.Button("Get start and end time")
    time_button.click(predict_durations, inputs=[model_checkpoint, text_in, video_in], outputs=[text_out])
    # preprocess=False passes the raw component payload (possibly a dict) to
    # predict_durations, which resolves it via _extract_video_filepath.
    examples = gr.Examples(examples=EXAMPLES, fn=predict_durations, inputs=[model_checkpoint, text_in, video_in], outputs=[text_out], cache_examples=True, preprocess=False)

if __name__ == "__main__":
    demo.launch(debug=True)