import io
import os
import cv2
import json
import numpy as np
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from vbench2_beta_i2v.utils import load_video, load_i2v_dimension_info, dino_transform, dino_transform_Image

import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def i2v_subject(model, video_pair_list, device):
    video_results = []
    sim_list = []
    # Weights for combining the per-video statistics into a single score.
    max_weight = 0.5
    mean_weight = 0.5
    min_weight = 0.0
    image_transform = dino_transform_Image(224)
    frames_transform = dino_transform(224)
    for image_path, video_path in tqdm(video_pair_list):
        # Preprocess the input image and extract its DINO feature.
        with torch.no_grad():
            input_image = image_transform(Image.open(image_path))
            input_image = input_image.unsqueeze(0)
            input_image = input_image.to(device)
            input_image_features = model(input_image)
            input_image_features = F.normalize(input_image_features, dim=-1, p=2)
        # Get the frames of the generated video.
        images = load_video(video_path)
        images = frames_transform(images)
        # Similarity between the input image and each frame (conformity),
        # and between consecutive frames (temporal consistency).
        conformity_scores = []
        consec_scores = []
        for i in range(len(images)):
            with torch.no_grad():
                image = images[i].unsqueeze(0)
                image = image.to(device)
                image_features = model(image)
                image_features = F.normalize(image_features, dim=-1, p=2)
                if i != 0:
                    sim_consec = max(0.0, F.cosine_similarity(former_image_features, image_features).item())
                    consec_scores.append(sim_consec)
                sim_to_input = max(0.0, F.cosine_similarity(input_image_features, image_features).item())
                conformity_scores.append(sim_to_input)
                former_image_features = image_features
        # Weighted combination of the best conformity score and the mean / minimum
        # consecutive-frame similarity (the minimum term is disabled by min_weight = 0.0).
        video_score = max_weight * np.max(conformity_scores) + \
                      mean_weight * np.mean(consec_scores) + \
                      min_weight * np.min(consec_scores)
        sim_list.append(video_score)
        video_results.append({'image_path': image_path, 'video_path': video_path, 'video_results': video_score})
    return np.mean(sim_list), video_results
def compute_i2v_subject(json_dir, device, submodules_list):
    # Load the DINO feature extractor via torch.hub using the configured submodule arguments.
    dino_model = torch.hub.load(**submodules_list).to(device)
    resolution = submodules_list['resolution']
    logger.info("Initialized DINO model successfully")
    video_pair_list, _ = load_i2v_dimension_info(json_dir, dimension='i2v_subject', lang='en', resolution=resolution)
    all_results, video_results = i2v_subject(dino_model, video_pair_list, device)
    return all_results, video_results
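

# Illustrative usage sketch (not part of the original module): the hub arguments
# below follow the public facebookresearch/dino entrypoint, while the JSON path
# and resolution are hypothetical placeholders that depend on your VBench setup.
if __name__ == "__main__":
    submodules = {
        "repo_or_dir": "facebookresearch/dino:main",  # assumption: public DINO hub repo
        "source": "github",
        "model": "dino_vitb16",                       # assumption: ViT-B/16 DINO backbone
        "resolution": 224,                            # hypothetical evaluation resolution
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # "i2v_info.json" is a placeholder for the metadata file expected by load_i2v_dimension_info.
    score, per_video = compute_i2v_subject("i2v_info.json", device, submodules)
    logger.info("i2v_subject score: %.4f", score)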