import os

import av
import cv2
import numpy as np
import torch
import gradio as gr
from transformers import AutoProcessor, TvpForVideoGrounding


def pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Convert the video from its original fps to the target_fps and decode the video with the PyAV decoder.
    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling.
            If clip_idx is larger than -1, uniformly split the video into num_clips
            clips and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have a different fps; convert it to
            the target fps before frame sampling.
    Returns:
        frames (list): decoded frames from the video. Returns None if no
            video stream was found.
        fps (float): the number of frames per second of the video.
    '''
    video = container.streams.video[0]
    fps = float(video.average_rate)
    frames_length = video.frames
    # Number of source frames covered by one clip at the target fps.
    clip_size = sampling_rate * num_frames / target_fps * fps
    delta = max(frames_length - clip_size, 0)
    start_idx = delta * clip_idx / num_clips
    end_idx = start_idx + clip_size - 1
    # Convert frame indices to presentation timestamps in the stream's time base.
    timebase = video.duration / frames_length
    video_start_pts = int(start_idx * timebase)
    video_end_pts = int(end_idx * timebase)
    seek_offset = max(video_start_pts - 1024, 0)
    container.seek(seek_offset, any_frame=False, backward=True, stream=video)
    frames = {}
    for frame in container.decode(video=0):
        if frame.pts < video_start_pts:
            continue
        frames[frame.pts] = frame
        if frame.pts > video_end_pts:
            break
    frames = [frames[pts] for pts in sorted(frames)]
    return frames, fps
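
# Worked example of the clip arithmetic above (illustrative numbers, not taken from a real video):
# for a 30 fps video with 600 frames (20 s), sampling_rate=1, num_frames=48, and target_fps=3,
# clip_size = 1 * 48 / 3 * 30 = 480 source frames (16 s of video) and delta = 600 - 480 = 120.
# With clip_idx=0 and num_clips=1 the clip spans source frames 0..479, which are then mapped to
# PTS bounds via timebase = video.duration / frames_length before seeking and decoding.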


def decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps):
    '''
    Decode the video and perform temporal sampling.
    Args:
        container (container): pyav container.
        sampling_rate (int): frame sampling rate (interval between two sampled frames).
        num_frames (int): number of frames to sample.
        clip_idx (int): if clip_idx is -1, perform random temporal sampling.
            If clip_idx is larger than -1, uniformly split the video into num_clips
            clips and select the clip_idx-th video clip.
        num_clips (int): overall number of clips to uniformly sample from the given video.
        target_fps (int): the input video may have a different fps; convert it to
            the target fps before frame sampling.
    Returns:
        frames (ndarray): decoded frames with shape (num_frames, channels, height, width).
    '''
    assert clip_idx >= -2, "Not a valid clip_idx {}".format(clip_idx)
    frames, fps = pyav_decode(container, sampling_rate, num_frames, clip_idx, num_clips, target_fps)
    clip_size = sampling_rate * num_frames / target_fps * fps
    # Pick num_frames evenly spaced indices within the clip and gather the corresponding frames.
    index = np.linspace(0, clip_size - 1, num_frames)
    index = np.clip(index, 0, len(frames) - 1).astype(np.int64)
    frames = np.array([frames[idx].to_rgb().to_ndarray() for idx in index])
    # Convert from (num_frames, height, width, channels) to (num_frames, channels, height, width).
    frames = frames.transpose(0, 3, 1, 2)
    return frames
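
# Standalone usage sketch (illustrative; the clip path and 48-frame budget are assumptions,
# not requirements of decode itself):
#
#   container = av.open("./examples/bed.mp4", metadata_errors="ignore")
#   sampled = decode(container, sampling_rate=1, num_frames=48, clip_idx=0, num_clips=1, target_fps=3)
#   print(sampled.shape)  # (48, 3, H, W), with H and W taken from the source video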


def get_video_duration(filename):
    '''Return the duration of the video in seconds, or -1 if the file cannot be opened.'''
    cap = cv2.VideoCapture(_extract_video_filepath(filename))
    if cap.isOpened():
        rate = cap.get(cv2.CAP_PROP_FPS)
        frame_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        duration = frame_num / rate
        return duration
    return -1


def _extract_video_filepath(video_filename):
    '''The Gradio Video input may pass either a plain path or a dict payload
    (e.g. {"video": {"path": ...}} when examples are used with preprocess=False).'''
    if isinstance(video_filename, dict):
        return video_filename['video']['path']
    return video_filename


def predict_durations(model_checkpoint, text, video_filename, device="cpu"):
    print(f"Loading model: {model_checkpoint}")
    model = TvpForVideoGrounding.from_pretrained(model_checkpoint).to(device)
    processor = AutoProcessor.from_pretrained(model_checkpoint)
    print(f"Loading video: {video_filename}")
    raw_sampled_frames = decode(
        container=av.open(_extract_video_filepath(video_filename), metadata_errors="ignore"),
        sampling_rate=1,
        num_frames=model.config.num_frames,
        clip_idx=0,
        num_clips=1,
        target_fps=3,
    )
    print("Processing video and text")
    model_inputs = processor(
        text=[text], videos=list(raw_sampled_frames), return_tensors="pt", max_text_length=100
    ).to(device)
    # If the checkpoint uses a non-float32 dtype, cast the pixel values to match:
    # model_inputs["pixel_values"] = model_inputs["pixel_values"].to(model.dtype)
    print("Running inference")
    output = model(**model_inputs)
    # Map the predicted start/end logits back to seconds using the full video duration.
    duration = get_video_duration(video_filename)
    start, end = processor.post_process_video_grounding(output.logits, duration)
    return f"start: {start}s, end: {end}s"


HF_TOKEN = os.environ.get("HF_TOKEN", None)
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
MODELS = ["Intel/tvp-base", "Intel/tvp-base-ANet"]
EXAMPLES = [
    ["Intel/tvp-base", "a person is sitting on a bed.", "./examples/bed.mp4"],
    ["Intel/tvp-base", "a person eats some food.", "./examples/food.mp4"],
    ["Intel/tvp-base", "a person reads a book.", "./examples/book.mp4"],
]

title = "Video Grounding with TVP"
DESCRIPTION = """# Video Grounding with TVP"""


with gr.Blocks(title=title) as demo:
    gr.Markdown(DESCRIPTION)
    gr.Markdown(
        """
Video grounding is the task of localizing the moment in a video that best matches a natural language description.
For example, given a video of a person sitting on a bed, the model should predict the start and end time of the segment that best matches the description "a person is sitting on a bed".
Enter a description of an event in the video and select a video to see the predicted start and end time.
"""
    )
    with gr.Row():
        model_checkpoint = gr.Dropdown(MODELS, label="Model", value=MODELS[0], type="value")
    with gr.Row(equal_height=True):
        with gr.Column(scale=0.5):
            video_in = gr.Video(label="Video File", elem_id="video_in")
        with gr.Column():
            text_in = gr.Textbox(label="Text", placeholder="Description of event in the video", interactive=True)
            text_out = gr.Textbox(label="Prediction", placeholder="Predicted start and end time")
            time_button = gr.Button("Get start and end time")
    time_button.click(predict_durations, inputs=[model_checkpoint, text_in, video_in], outputs=[text_out])
    examples = gr.Examples(
        examples=EXAMPLES,
        fn=predict_durations,
        inputs=[model_checkpoint, text_in, video_in],
        outputs=[text_out],
        cache_examples=True,
        preprocess=False,
    )

if __name__ == "__main__":
    demo.launch(debug=True)
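
# To try this Space locally (a sketch; the package list is an assumption, not pinned by this file):
#   pip install torch transformers gradio av opencv-python numpy
#   python app.py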