tvp / app.py
Amy Roberts
Finalise
f2b92aa
raw history blame
No virus
6.09 kB
import functools
import os

import av
import cv2
import gradio as gr
import numpy as np
import torch
from transformers import AutoProcessor, TvpForVideoGrounding
def pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    """Decode the frames covering one temporal clip of a video with PyAV.

    The clip length is measured in frames of the *source* stream:
    ``clip_size = sampling_rate * num_frames / target_fps * fps``. Only the part
    of the stream spanning that clip (plus the first frame past its end) is
    decoded. If the stream does not report a duration or a frame count, the
    whole stream is decoded instead.

    Args:
        container: opened PyAV input container.
        sampling_rate (int): interval, in target-fps frames, between two sampled frames.
        num_frames (int): number of frames the caller will sample from the clip.
        clip_idx (int): index of the clip to decode when the video is split
            uniformly into ``num_clips`` clips; expected in ``[0, num_clips)``.
            NOTE: the start offset is ``delta * clip_idx / num_clips`` for every
            value, so -1 does NOT perform random temporal sampling here.
        num_clips (int): number of clips the video is uniformly split into.
        target_fps (int): frame rate in which ``sampling_rate`` is expressed.

    Returns:
        tuple: ``(frames, fps)`` where ``frames`` is the list of decoded frames
        (those with pts inside the clip window, plus the first one past it),
        sorted by pts, and ``fps`` is the stream's average frame rate as a float.

    Raises:
        IndexError: if the container has no video stream.
    """
    video = container.streams.video[0]
    fps = float(video.average_rate)
    clip_size = sampling_rate * num_frames / target_fps * fps

    total_frames = video.frames
    duration = video.duration
    if duration is None or not total_frames:
        # Without duration / frame-count metadata we cannot map frame indices
        # to pts, so decode the whole stream.
        video_start_pts, video_end_pts = 0, float("inf")
    else:
        # Frame indices refer to positions in the source stream, so both the
        # clip offset and the pts-per-frame factor must use the stream's frame
        # count (not the number of frames that will be sampled).
        delta = max(total_frames - clip_size, 0)
        start_idx = delta * clip_idx / num_clips
        end_idx = start_idx + clip_size - 1
        timebase = duration / total_frames
        # Streams whose pts do not start at 0 would otherwise lose the tail of
        # the clip window; negative start times are ignored, so pts < 0 frames
        # are still skipped.
        first_pts = max(video.start_time or 0, 0)
        video_start_pts = first_pts + int(start_idx * timebase)
        video_end_pts = first_pts + int(end_idx * timebase)

    # Seek a little before the window start; seeking lands on a keyframe at
    # or before the requested offset (backward=True, any_frame=False).
    seek_offset = max(video_start_pts - 1024, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=video)

    by_pts = {}
    for frame in container.decode(video=0):
        if frame.pts is None or frame.pts < video_start_pts:
            continue
        by_pts[frame.pts] = frame
        if frame.pts > video_end_pts:
            break
    frames = [by_pts[pts] for pts in sorted(by_pts)]
    return frames, fps
def decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    """Decode one clip of a video and sample ``num_frames`` evenly spaced frames.

    Args:
        container: opened PyAV input container.
        sampling_rate (int): interval, in target-fps frames, between two sampled frames.
        num_frames (int): number of frames to sample.
        clip_idx (int): clip index forwarded to ``pyav_decode``; values below -2
            are rejected.
        num_clips (int): number of clips the video is uniformly split into.
        target_fps (int): frame rate in which ``sampling_rate`` is expressed.

    Returns:
        numpy.ndarray: the sampled frames converted to RGB, stacked along axis 0
        and transposed to channel-first order, i.e. shape
        ``(num_frames, 3, height, width)`` for RGB frames.

    Raises:
        ValueError: if ``clip_idx`` is below -2 or no frame could be decoded.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if clip_idx < -2:
        raise ValueError(f"Not a valid clip_idx {clip_idx}")
    frames, fps = pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps)
    if not frames:
        # np.clip(..., 0, -1) below would otherwise yield index -1 and an opaque IndexError.
        raise ValueError("No frames could be decoded from the video")
    clip_size = sampling_rate * num_frames / target_fps * fps
    # Evenly spaced source-frame positions across the clip; positions beyond a
    # short video's end are clamped to its last decoded frame.
    index = np.linspace(0, clip_size - 1, num_frames)
    index = np.clip(index, 0, len(frames) - 1).astype(np.int64)
    sampled = np.array([frames[idx].to_rgb().to_ndarray() for idx in index])
    return sampled.transpose(0, 3, 1, 2)
def get_video_duration(filename):
    """Return a video's duration in seconds as estimated by OpenCV.

    The duration is the reported frame count divided by the reported frame
    rate (``CAP_PROP_FRAME_COUNT / CAP_PROP_FPS``).

    Args:
        filename (str): path to the video file.

    Returns:
        float: duration in seconds, or -1 if the file cannot be opened or it
        reports a non-positive frame rate.
    """
    cap = cv2.VideoCapture(filename)
    try:
        if not cap.isOpened():
            return -1
        rate = cap.get(cv2.CAP_PROP_FPS)
        frame_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        if rate <= 0:
            # Some containers report 0 fps; avoid ZeroDivisionError and signal failure.
            return -1
        return frame_num / rate
    finally:
        # Always free the underlying decoder / file handle.
        cap.release()
@functools.lru_cache(maxsize=2)
def _load_model_and_processor(model_checkpoint):
    """Load the TVP model and processor for a checkpoint, memoized per checkpoint."""
    print(f"Loading model: {model_checkpoint}")
    model = TvpForVideoGrounding.from_pretrained(model_checkpoint)
    processor = AutoProcessor.from_pretrained(model_checkpoint)
    return model, processor


def predict_durations(model_checkpoint, text, video_filename, device="cpu"):
    """Predict the start and end time of the event described by ``text``.

    Args:
        model_checkpoint (str): Hugging Face checkpoint name for the TVP model.
        text (str): natural-language description of the event to locate.
        video_filename (str): path to the uploaded video.
        device (str | torch.device): device that the model and inputs are moved to.

    Returns:
        str: ``"start: {start}s, end: {end}s"`` built from
        ``processor.post_process_video_grounding`` applied to the model logits
        and the video duration.

    Raises:
        gr.Error: if no video was provided or its duration cannot be determined.
    """
    if not video_filename:
        raise gr.Error("Please upload a video first.")

    # Loading is cached so repeated clicks do not re-instantiate the model.
    model, processor = _load_model_and_processor(model_checkpoint)
    # Inputs are moved to `device` below, so the model must live there too.
    model = model.to(device)

    print(f"Loading video: {video_filename}")
    # Context manager closes the container (and its file handle) after decoding.
    with av.open(video_filename, metadata_errors="ignore") as container:
        raw_sampled_frames = decode(
            container=container,
            sampling_rate=1,
            num_frames=model.config.num_frames,
            clip_idx=0,
            num_clips=1,
            target_fps=3,
        )

    print("Processing video and text")
    model_inputs = processor(
        text=[text], videos=list(raw_sampled_frames), return_tensors="pt", max_text_length=100
    ).to(device)

    print("Running inference")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**model_inputs)

    duration = get_video_duration(video_filename)
    if duration <= 0:
        raise gr.Error("Could not determine the video duration.")
    start, end = processor.post_process_video_grounding(output.logits, duration)
    return f"start: {start}s, end: {end}s"
# --- Runtime configuration ---------------------------------------------------
# NOTE(review): HF_TOKEN is not referenced elsewhere in the code shown here.
HF_TOKEN = os.environ.get("HF_TOKEN")
# First CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Checkpoints offered in the model dropdown; the first one is preselected.
MODELS = ["Intel/tvp-base", "Intel/tvp-base-ANet"]
# [video path, query text] rows used by gr.Examples.
EXAMPLES = [
    ["./examples/bed.mp4", "a person is sitting on a bed."],
    ["./examples/food.mp4", "a person eats some food."],
    ["./examples/book.mp4", "a person reads a book."],
]

# --- UI components (created unrendered, placed later via .render()) ----------
model_checkpoint = gr.Dropdown(choices=MODELS, value=MODELS[0], type="value", label="Model")
video_in = gr.Video(elem_id="video_in", label="Video File")
text_in = gr.Textbox(
    placeholder="Description of event in the video",
    interactive=True,
    label="Text",
)
text_out = gr.Textbox(placeholder="Predicted start and end time", label="Prediction")

# --- Page text and styling ---------------------------------------------------
title = "Video Grounding with TVP"
DESCRIPTION = "# Video Grounding with TVP"
# NOTE(review): css is not passed to gr.Blocks in the code shown here, so it has no effect.
css = ".toast-wrap { display: none !important } "
# Assemble the page. The components created above are placed with .render(),
# so the order of statements inside each context manager defines the layout.
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)
    # Row 1: checkpoint selector.
    with gr.Row():
        model_checkpoint.render()
    # Row 2: example rows that fill the video and text inputs when clicked.
    with gr.Row():
        examples = gr.Examples(examples=EXAMPLES, inputs=[video_in, text_in])
    # Row 3: video input on the left; query, button and prediction on the right.
    with gr.Row():
        with gr.Column():
            video_in.render()
        with gr.Column():
            text_in.render()
            time_button = gr.Button("Get start and end time")
            # Calls predict_durations(checkpoint, text, video) and shows the returned string in text_out.
            time_button.click(predict_durations, inputs=[model_checkpoint, text_in, video_in], outputs=[text_out])
            text_out.render()

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True)