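"""Demo: text-to-video retrieval with a locally stored video-text checkpoint loaded
via transformers AutoModel. Frames are sampled with decord, encoded by the model's
vision encoder, and ranked against text embeddings by cosine similarity."""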
import os
import random
import io
import av
import cv2
import decord
import imageio
from decord import VideoReader
import torch
import numpy as np
import math
import torch.nn.functional as F
decord.bridge.set_bridge("torch")
from transformers import AutoConfig, AutoModel
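# NOTE: the path below is a local checkpoint directory; replace it with the path
# (or hub id) of your own copy of the model weights.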
config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1, start=None, end=None):
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))
    # Ensure start_frame is less than end_frame
    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")
    # Calculate the length of the clip in frames
    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:  # uniform sampling
        acc_samples = min(num_frames, clip_length)
        # split the clip into `acc_samples` intervals, and sample from each interval
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except IndexError:
                # an interval can be empty for very short clips; fall back to a sorted random permutation
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError
        if len(frame_indices) < num_frames:  # pad with the last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:  # e.g. "fps0.5": sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between frames; also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
        # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError(f"Unknown sample mode: {sample}")
    return frame_indices
def read_frames_decord(
        video_path, num_frames, sample='middle', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
):
    num_threads = 1 if video_path.endswith('.webm') else 0  # make ssv2 happy
    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )
    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C), torch.uint8
    frames = frames.permute(0, 3, 1, 2)  # (T, C, H, W), torch.uint8
    return frames, frame_indices, duration
def get_text_feature(model, texts):
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features


def get_similarity(video_feature, text_feature):
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix
def get_top_videos(model, text_features, video_features, video_paths, texts):
    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)
    sim_matrix = text_features @ video_features.T  # (num_texts, num_videos)
    top_k = min(5, len(video_paths))  # don't request more candidates than there are videos
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]
    softmax_sim_matrix = F.softmax(sim_matrix, dim=1)  # per-text probabilities over videos
    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            video_idx = sim_matrix_top_k[i][j].item()
            print("top", j + 1, ":", video_paths[video_idx], "~prob:", softmax_sim_matrix[i][video_idx].item())
            retrieval_infos[texts[i]].append({
                "video": video_paths[video_idx],
                "prob": softmax_sim_matrix[i][video_idx].item(),
                "rank": j + 1,
            })
    return retrieval_infos
if __name__ == "__main__":
    video_features = []
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']
    for video_path in demo_videos:
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
        video_features.append(video_feature)
    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)
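# Hypothetical variation: restrict sampling to a sub-clip by passing start/end (in seconds), e.g.
#   frames, idx, dur = read_frames_decord("video1.mp4", 8, sample="middle", start=0, end=10)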