Spaces:
Runtime error
Runtime error
import functools
import os

import av
import cv2
import gradio as gr
import numpy as np
import torch
from transformers import AutoProcessor, TvpForVideoGrounding
def pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Locate one clip of the video on the target_fps timeline and decode its frames with PyAV.
    Args:
        container (container): pyav input container opened on the video.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames that will be sampled from the clip.
        clip_idx (int): index of the clip to decode when the video is uniformly split
            into num_clips clips. Negative values are not special-cased: no random
            temporal sampling is implemented here.
        num_clips (int): overall number of clips the video is uniformly split into.
        target_fps (int): frame rate the sampling parameters are expressed in; the
            clip length is converted from target_fps to the video's own fps.
    Returns:
        frames (list): decoded PyAV frames covering the clip, sorted by pts.
        fps (float): the average number of frames per second of the video stream.
    Raises:
        ValueError: if the container has no video stream, or neither the stream
            nor the container reports a duration.
    '''
    if not container.streams.video:
        raise ValueError("The container has no video stream.")
    video = container.streams.video[0]
    fps = float(video.average_rate)
    # Clip length converted from target_fps frames to source-fps frames.
    clip_size = sampling_rate * num_frames / target_fps * fps
    delta = max(num_frames - clip_size, 0)
    start_idx = delta * clip_idx / num_clips
    end_idx = start_idx + clip_size - 1

    duration = video.duration
    if duration is None and container.duration is not None:
        # The stream reports no duration: fall back to the container's, which is
        # in av.time_base units (microseconds), and express it in the stream's
        # own time_base so it is comparable with frame.pts.
        av_time_base = 1_000_000  # value of av.time_base
        duration = container.duration / av_time_base / video.time_base
    if duration is None:
        raise ValueError("Cannot determine the duration of the video stream.")

    # NOTE(review): frame indices are mapped to pts by dividing the stream
    # duration by num_frames (the sampled-frame count, not the stream's total
    # frame count). Kept unchanged so model inputs stay the same; confirm intent.
    timebase = duration / num_frames
    video_start_pts = int(start_idx * timebase)
    video_end_pts = int(end_idx * timebase)
    # Seek 1024 pts units before the window start, landing on a keyframe at or
    # before that point, so decoding can reach the first wanted frame.
    seek_offset = max(video_start_pts - 1024, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=video)

    frames = {}
    for frame in container.decode(video=0):
        if frame.pts is None:
            # A frame without a timestamp cannot be placed in the clip window.
            continue
        if frame.pts < video_start_pts:
            continue
        frames[frame.pts] = frame
        if frame.pts > video_end_pts:
            break
    return [frames[pts] for pts in sorted(frames)], fps
def decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Decode one clip of the video and sample num_frames evenly spaced RGB frames from it.
    Args:
        container (container): pyav input container, passed through to pyav_decode.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to return.
        clip_idx (int): index of the clip to take when the video is uniformly split
            into num_clips clips. Values down to -2 are accepted, but no random
            temporal sampling is implemented for negative values.
        num_clips (int): overall number of clips the video is uniformly split into.
        target_fps (int): frame rate the sampling parameters are expressed in; the
            clip length is converted from target_fps to the video's own fps.
    Returns:
        frames (numpy.ndarray): stacked RGB frames laid out as (num_frames, C, H, W).
    Raises:
        ValueError: if clip_idx is below -2 or no frame could be decoded.
    '''
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if clip_idx < -2:
        raise ValueError("Not a valid clip_idx {}".format(clip_idx))
    frames, fps = pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps)
    if not frames:
        # Without this, np.clip below yields index -1 and fails with an opaque IndexError.
        raise ValueError("No video frames could be decoded from the container.")
    # Clip length measured in source-fps frames.
    clip_size = sampling_rate * num_frames / target_fps * fps
    # Evenly spaced positions across the clip, clamped to the frames actually decoded.
    index = np.linspace(0, clip_size - 1, num_frames)
    index = np.clip(index, 0, len(frames) - 1).astype(np.int64)
    frames = np.array([frames[idx].to_rgb().to_ndarray() for idx in index])
    # Channels-last (T, H, W, C) -> channels-first (T, C, H, W).
    return frames.transpose(0, 3, 1, 2)
def get_video_duration(filename):
    """Return the length of a video in seconds, as reported by OpenCV.

    Args:
        filename: path to the video, or a Gradio video value accepted by
            ``_extract_video_filepath``.

    Returns:
        float: frame count divided by frame rate, or -1 if the file cannot be
        opened or reports a non-positive frame rate.
    """
    cap = cv2.VideoCapture(_extract_video_filepath(filename))
    try:
        if not cap.isOpened():
            return -1
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        if fps <= 0:
            # Some streams report no frame rate; avoid ZeroDivisionError.
            return -1
        return frame_count / fps
    finally:
        # Release the capture handle on every path (the original never did).
        cap.release()
def _extract_video_filepath(video_filename): | |
if isinstance(video_filename, dict): | |
return video_filename['video']['path'] | |
return video_filename | |
@functools.lru_cache(maxsize=2)
def _load_model_and_processor(model_checkpoint):
    """Load the TVP model and processor for a checkpoint, once per checkpoint.

    Cached so that repeated button clicks and the example-caching pass do not
    re-instantiate the weights on every call. maxsize=2 matches the two
    checkpoints offered in the UI dropdown.
    """
    print(f"Loading model: {model_checkpoint}")
    model = TvpForVideoGrounding.from_pretrained(model_checkpoint)
    processor = AutoProcessor.from_pretrained(model_checkpoint)
    return model, processor


def predict_durations(model_checkpoint, text, video_filename, device="cpu"):
    """Predict the start and end time of the moment in a video that matches *text*.

    Args:
        model_checkpoint: Hugging Face Hub id of a TVP checkpoint.
        text: natural-language description of the event to localize.
        video_filename: video value from the Gradio component; either a file
            path or a dict accepted by ``_extract_video_filepath``.
        device: torch device that both the model and the inputs are placed on.

    Returns:
        str: ``"start: <s>s, end: <e>s"`` as computed by the processor's
        ``post_process_video_grounding``.

    Raises:
        gr.Error: if no video was supplied or its duration cannot be determined.
    """
    if video_filename is None:
        raise gr.Error("Please provide a video.")
    model, processor = _load_model_and_processor(model_checkpoint)
    # The inputs are moved to `device` below; keep the weights there too.
    model = model.to(device)

    filepath = _extract_video_filepath(video_filename)
    print(f"Loading video: {filepath}")
    # Close the container once the frames have been sampled.
    with av.open(filepath, metadata_errors="ignore") as container:
        raw_sampled_frames = decode(
            container=container,
            sampling_rate=1,
            num_frames=model.config.num_frames,
            clip_idx=0,
            num_clips=1,
            target_fps=3,
        )

    print("Processing video and text")
    model_inputs = processor(
        text=[text], videos=list(raw_sampled_frames), return_tensors="pt", max_text_length=100
    ).to(device)

    print("Running inference")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**model_inputs)

    duration = get_video_duration(video_filename)
    if duration <= 0:
        # get_video_duration signals failure with -1; scaling by it would
        # produce meaningless negative timestamps.
        raise gr.Error("Could not determine the video duration.")
    start, end = processor.post_process_video_grounding(output.logits, duration)
    return f"start: {start}s, end: {end}s"
# Optional Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# First CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Checkpoints selectable in the "Model" dropdown.
MODELS = ["Intel/tvp-base", "Intel/tvp-base-ANet"]
# Clickable example rows: [checkpoint, description, video path].
EXAMPLES = [
    ["Intel/tvp-base", prompt, clip_path]
    for prompt, clip_path in (
        ("a person is sitting on a bed.", "./examples/bed.mp4"),
        ("a person eats some food.", "./examples/food.mp4"),
        ("a person reads a book.", "./examples/book.mp4"),
    )
]
# Browser tab title and page heading.
title = "Video Grounding with TVP"
DESCRIPTION = "# Video Grounding with TVP"
# Gradio UI: pick a checkpoint, upload a video, describe an event, get its time span.
with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown(
        """
Video Grounding is the task of localizing a moment in a video that best matches a natural language description.
For example, given the video of a person sitting on a bed, the model should be able to predict the start and end time of the video that best matches the description "a person is sitting on a bed".
Enter a description of an event in the video and select a video to see the predicted start and end time.
"""
    )
    with gr.Row():
        model_checkpoint = gr.Dropdown(MODELS, label="Model", value=MODELS[0], type="value")
    with gr.Row(equal_height=True):
        # Gradio expects integer column scales; 1:2 keeps the video column half
        # as wide as the text column, the same ratio as the former 0.5 vs. default 1.
        with gr.Column(scale=1):
            video_in = gr.Video(label="Video File", elem_id="video_in")
        with gr.Column(scale=2):
            text_in = gr.Textbox(label="Text", placeholder="Description of event in the video", interactive=True)
            text_out = gr.Textbox(label="Prediction", placeholder="Predicted start and end time")
            time_button = gr.Button("Get start and end time")
    time_button.click(predict_durations, inputs=[model_checkpoint, text_in, video_in], outputs=[text_out])
    # cache_examples=True makes Gradio run predict_durations on every example
    # ahead of time; preprocess=False hands the raw example values to the
    # function, which _extract_video_filepath is there to unpack.
    examples = gr.Examples(
        examples=EXAMPLES,
        fn=predict_durations,
        inputs=[model_checkpoint, text_in, video_in],
        outputs=[text_out],
        cache_examples=True,
        preprocess=False,
    )

if __name__ == "__main__":
    demo.launch(debug=True)