|
import os
import io
import math
import random

import av
import cv2
import decord
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from decord import VideoReader
from transformers import AutoConfig, AutoModel

# Make decord's VideoReader return torch tensors directly.
decord.bridge.set_bridge("torch")

# Load the checkpoint from the local directory; trust_remote_code is required
# because the repo ships a custom model class. config.device comes from the
# checkpoint's custom configuration.
config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)
|
|
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1,
                      max_num_frames=-1, start=None, end=None):
    """Select frame indices from a video with `vlen` frames.

    `sample` is either 'rand'/'middle' (split the clip into equal intervals and
    take a random or middle frame from each) or 'fpsX' (sample X frames per
    second). Optional `start`/`end` (in seconds) restrict sampling to a sub-clip.
    """
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))

    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, clip_length)

        # Split the clip into `acc_samples` equal intervals and pick one frame from each.
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except IndexError:
                # Some intervals were empty; fall back to a random permutation of the clip.
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        # Pad by repeating the last index when the clip has fewer frames than requested.
        if len(frame_indices) < num_frames:
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:
        # e.g. sample='fps1' keeps one frame per second of video.
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between sampled frames, in seconds
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unsupported sampling strategy: {sample}")
    return frame_indices
|
|
def read_frames_decord(
        video_path, num_frames, sample='middle', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
):
    """Decode `num_frames` frames from `video_path` with decord.

    Returns a (T, C, H, W) frame tensor, the sampled frame indices, and the
    video duration in seconds.
    """
    # Use a single decoding thread for .webm inputs.
    num_threads = 1 if video_path.endswith('.webm') else 0

    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)

    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )

    # With the torch bridge, get_batch returns a (T, H, W, C) tensor; reorder to (T, C, H, W).
    frames = video_reader.get_batch(frame_indices)
    frames = frames.permute(0, 3, 1, 2)
    return frames, frame_indices, duration
|
|
def get_text_feature(model, texts):
    """Tokenize `texts` and encode them with the model's text encoder."""
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features
|
|
def get_similarity(video_feature, text_feature):
    """Cosine-similarity matrix between texts (rows) and videos (columns)."""
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix
|
|
def get_top_videos(model, text_features, video_features, video_paths, texts):
    """Print and return the top-k retrieved videos for each text query."""
    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # (num_texts, num_videos) cosine-similarity matrix.
    sim_matrix = text_features @ video_features.T

    # Cap k at the number of candidate videos so topk never goes out of range.
    top_k = min(5, sim_matrix.shape[1])
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]
    # Softmax over videos turns similarities into per-text retrieval probabilities.
    softmax_sim_matrix = F.softmax(sim_matrix, dim=1)

    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            video_idx = sim_matrix_top_k[i][j].item()
            print("top", j + 1, ":", video_paths[video_idx], "~prob:", softmax_sim_matrix[i][video_idx].item())
            retrieval_infos[texts[i]].append({
                "video": video_paths[video_idx],
                "prob": softmax_sim_matrix[i][video_idx].item(),
                "rank": j + 1,
            })
    return retrieval_infos
|
|
if __name__ == "__main__":
    video_features = []
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']

    # Encode each demo video into a single feature vector.
    for video_path in demo_videos:
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
        video_features.append(video_feature)

    # Encode the text queries and rank the videos for each of them.
    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)
|
|
|
|
|