Spaces:
Running
Running
import gradio as gr | |
#import spaces | |
import os | |
import time | |
import json | |
import numpy as np | |
import av | |
import torch | |
from PIL import Image | |
import functools | |
from transformers import AutoProcessor, Idefics2ForConditionalGeneration | |
from models.conversation import conv_templates | |
from typing import List | |
# Hugging Face Hub id of the fine-tuned Idefics2 video-evaluation model.
MODEL_ID = "Mantis-VL/mantis-8b-idefics2-video-eval-refined-40k_4096_generation"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Idefics2ForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
).eval()

# Maximum number of frames sampled from one video for an evaluation.
MAX_NUM_FRAMES = 24
conv_template = conv_templates["idefics_2"]

# Directory holding the example JSON files and the media they reference.
EXAMPLES_DIR = "./examples"


def _load_json(path):
    """Load and return the JSON document stored (UTF-8 encoded) at ``path``."""
    # Explicit encoding: the platform default is locale-dependent and can
    # mis-decode non-ASCII prompts.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# Media paths inside the JSON files are relative to EXAMPLES_DIR.
examples = _load_json(os.path.join(EXAMPLES_DIR, "all_subsets.json"))
for item in examples:
    # Skip (rather than crash the whole app) examples without frame images.
    if item.get("images"):
        # Frame images live in a sub-directory named after the part of the
        # first image's file name before its first "_".
        video_id = item["images"][0].split("_")[0]
        item["images"] = [os.path.join(EXAMPLES_DIR, video_id, x) for x in item["images"]]
    item["video"] = os.path.join(EXAMPLES_DIR, item["video"])

hd_examples = _load_json(os.path.join(EXAMPLES_DIR, "hd.json"))
for item in hd_examples:
    item["video"] = os.path.join(EXAMPLES_DIR, item["video"])

# HD examples come first in the combined list shown by gr.Examples.
examples = hd_examples + examples
# Instruction template sent to the evaluator model. eval_video() fills
# {text_prompt} with str.format and then appends a "<video>" placeholder,
# which generate() expands into one "<image>" per sampled frame.
# The template asks for an integer score in [1, 4] on 5 dimensions.
# Because the string goes through str.format, any literal "{" or "}" added
# here must be doubled. NOTE(review): the model was presumably fine-tuned on
# this exact wording, so editing it may change scores; verify before changing.
VIDEO_EVAL_PROMPT = """
Suppose you are an expert in judging and evaluating the quality of AI-generated videos,
please watch the following frames of a given video and see the text prompt for generating the video,
then give scores from 5 different dimensions:
(1) visual quality: the quality of the video in terms of clearness, resolution, brightness, and color
(2) temporal consistency, the consistency of objects or humans in video
(3) dynamic degree, the degree of dynamic changes
(4) text-to-video alignment, the alignment between the text prompt and the video content
(5) factual consistency, the consistency of the video content with the common-sense and factual knowledge
For each dimension, output a number from [1,2,3,4],
in which '1' means 'Bad', '2' means 'Average', '3' means 'Good',
'4' means 'Real' or 'Perfect' (the video is like a real video)
Here is an output example:
visual quality: 4
temporal consistency: 4
dynamic degree: 3
text-to-video alignment: 1
factual consistency: 2
For this video, the text prompt is "{text_prompt}",
all the frames of video are as follows:
"""
def _history_to_idefics2_messages(history, all_videos, user_role, assistant_role):
    """Convert ``{"role", "text"}`` chat messages into Idefics2 chat-template messages.

    Each ``<video>`` in a user message is replaced by one ``<image>``
    placeholder per frame of the next video in ``all_videos``. Each user message
    then becomes a single user turn whose content interleaves text and image
    entries in placeholder order. A final assistant message with empty text
    stops the conversion, so the model is prompted to answer the last user turn.
    """
    messages = []
    cur_vid_idx = 0
    for i, message in enumerate(history):
        if message["role"] == user_role:
            message_text = message["text"]
            for _ in range(message_text.count("<video>")):
                message_text = message_text.replace(
                    "<video>", "<image> " * len(all_videos[cur_vid_idx]), 1
                )
                cur_vid_idx += 1
            content = []
            if "<image>" in message_text:
                sub_texts = [x.strip() for x in message_text.split("<image>")]
                if sub_texts[0]:
                    content.append({"type": "text", "text": sub_texts[0]})
                for sub_text in sub_texts[1:]:
                    content.append({"type": "image"})
                    # Text following an image stays in the SAME user turn; it
                    # was previously emitted as a separate user message.
                    if sub_text:
                        content.append({"type": "text", "text": sub_text})
            else:
                content.append({"type": "text", "text": message_text})
            messages.append({"role": user_role, "content": content})
        elif message["role"] == assistant_role:
            if i == len(history) - 1 and not message["text"]:
                break
            messages.append({
                "role": assistant_role,
                "content": [{"type": "text", "text": message["text"]}],
            })
    return messages


#@spaces.GPU(duration=60)
def generate(text: str, images: List[Image.Image], history: List[dict], **kwargs):
    """Generate the model's reply to a chat history with interleaved images/videos.

    Args:
        text: Optional extra user text appended after ``history``. It is
            ignored when falsy. If the last converted turn is a user turn, the
            text is added to that turn; otherwise it starts a new user turn.
        images: Media for the ``<image>``/``<video>`` placeholders in
            ``history``, in order of appearance. A list entry is one video
            (its frames). Any other entry is one image: either an image
            object or a path string opened with ``Image.open``. May be empty
            or None.
        history: Messages as ``{"role": ..., "text": ...}`` dicts whose roles
            are ``conv_template.roles[0]`` (user) or ``[1]`` (assistant). A
            final assistant message with empty text is skipped.
        **kwargs: Accepted for caller compatibility; unused.

    Returns:
        str: The decoded tokens generated after the prompt, with special
        tokens removed.
    """
    # Accept None or [] (text-only conversations) without crashing below.
    images = images or []
    user_role = conv_template.roles[0]
    assistant_role = conv_template.roles[1]
    all_videos = [x for x in images if isinstance(x, list)]
    # The processor takes a flat image list: a video contributes each frame.
    flatten_images = [
        frame for x in images for frame in (x if isinstance(x, list) else [x])
    ]
    print(history)
    idefics_2_message = _history_to_idefics2_messages(
        history, all_videos, user_role, assistant_role
    )
    if text:
        if idefics_2_message and idefics_2_message[-1]["role"] == user_role:
            idefics_2_message[-1]["content"].append({"type": "text", "text": text})
        else:
            idefics_2_message.append({
                "role": user_role,
                "content": [{"type": "text", "text": text}],
            })
    print(idefics_2_message)
    prompt = processor.apply_chat_template(idefics_2_message, add_generation_prompt=True)
    pil_images = [Image.open(x) if isinstance(x, str) else x for x in flatten_images]
    # The model was placed at load time (device_map="auto"), so it is not
    # moved here. A forced .to("cuda") fails on CPU-only hosts and would
    # undo a multi-device placement. Inputs follow the model's device instead.
    inputs = processor(text=prompt, images=pil_images or None, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    outputs = model.generate(**inputs, max_new_tokens=1024)
    # Decode only the newly generated tokens, not the echoed prompt.
    return processor.decode(
        outputs[0, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )
def read_video_pyav(container, indices):
    '''
    Decode selected frames of a video with the PyAV decoder.

    Args:
        container (av.container.input.InputContainer): PyAV container.
        indices (Iterable[int]): Frame indices to decode. Order and duplicates
            do not matter; frames are returned in decode order, each at most once.
    Returns:
        np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
    Raises:
        ValueError: If ``indices`` is empty, or none of the requested frames
            exist in the stream.
    '''
    # A set makes each membership test O(1). `i in ndarray` builds a full
    # boolean array for every decoded frame.
    wanted = {int(i) for i in indices}
    if not wanted:
        raise ValueError("read_video_pyav: no frame indices were requested")
    # Use the maximum, not the last element, so unsorted input still works.
    end_index = max(wanted)
    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > end_index:
            # No requested frame can follow; stop decoding early.
            break
        if i in wanted:
            frames.append(frame.to_ndarray(format="rgb24"))
    if not frames:
        raise ValueError(
            "read_video_pyav: none of the requested frame indices exist in the video"
        )
    return np.stack(frames)
def eval_video(prompt, video: str):
    """Ask the evaluator model to score a video against its generation prompt.

    Args:
        prompt: The text prompt that was used to generate the video.
        video: Path to the video file, as passed by the ``gr.Video`` input.

    Returns:
        The model's text response to ``VIDEO_EVAL_PROMPT``, as returned by
        ``generate``.

    Raises:
        gr.Error: If no video was provided.
    """
    if not video:
        raise gr.Error("Please upload a video before evaluating.")
    # `with` closes the container (and its file handle) after decoding.
    with av.open(video) as container:
        total_frames = container.streams.video[0].frames
        if total_frames <= 0:
            # The stream header may not report a frame count. Count frames by
            # decoding rather than sampling nothing.
            total_frames = sum(1 for _ in container.decode(video=0))
        if total_frames > MAX_NUM_FRAMES:
            # Exactly MAX_NUM_FRAMES evenly spaced indices. A float-step
            # np.arange can yield one extra element through rounding.
            indices = np.linspace(
                0, total_frames, MAX_NUM_FRAMES, endpoint=False
            ).astype(int)
        else:
            indices = np.arange(total_frames)
        video_frames = read_video_pyav(container, indices)
    frames = [Image.fromarray(x) for x in video_frames]
    # The "<video>" placeholder is expanded by generate() into one "<image>"
    # per sampled frame.
    eval_prompt = VIDEO_EVAL_PROMPT.format(text_prompt=prompt) + "<video>"
    user_role = conv_template.roles[0]
    assistant_role = conv_template.roles[1]
    chat_messages = [
        {"role": user_role, "text": eval_prompt},
        # A trailing empty assistant turn is skipped by generate(), which then
        # prompts the model to answer the user turn.
        {"role": assistant_role, "text": ""},
    ]
    return generate(None, [frames], chat_messages)
def build_demo():
    """Assemble the Gradio Blocks UI for the video-evaluation demo.

    Returns:
        gr.Blocks: The constructed demo; the caller is responsible for launching it.
    """
    with gr.Blocks() as demo:
        # The evaluation prompt scores 5 dimensions. The previous text said 7.
        gr.Markdown(
            "## Video Evaluation\n"
            "Upload a video together with the text prompt that was used to generate it; "
            "the model will evaluate the video's quality on 5 different dimensions."
        )
        with gr.Row():
            video = gr.Video(width=500, label="Video")
            with gr.Column():
                # Read-only view of the instruction template sent to the model.
                gr.Textbox(VIDEO_EVAL_PROMPT.strip(' \n'), label="Evaluation Prompt Template", interactive=False, max_lines=26)
                video_prompt = gr.Textbox(label="Text Prompt", lines=1)
                with gr.Row():
                    eval_button = gr.Button("Evaluate Video")
                    gr.ClearButton([video, video_prompt])
                eval_result = gr.Textbox(label="Evaluation result", interactive=False, lines=7)
        eval_button.click(
            eval_video, [video_prompt, video], [eval_result]
        )
        # Hidden components that only receive the id / reference-score columns
        # of each example row.
        dummy_id = gr.Textbox("id", label="id", visible=False, min_width=50)
        dummy_output = gr.Textbox("reference score", label="reference scores", visible=False, lines=7)
        gr.Examples(
            examples=[
                [
                    item['id'],
                    item['prompt'],
                    item['video'],
                    item['conversations'][1]['value'],
                ]
                for item in examples
            ],
            inputs=[dummy_id, video_prompt, video, dummy_output],
        )
        # gr.Markdown("""
        # ## Citation
        # ```
        # @article{jiang2024mantis,
        # title={MANTIS: Interleaved Multi-Image Instruction Tuning},
        # author={Jiang, Dongfu and He, Xuan and Zeng, Huaye and Wei, Con and Ku, Max and Liu, Qian and Chen, Wenhu},
        # journal={arXiv preprint arXiv:2405.01483},
        # year={2024}
        # }
        # ```""")
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then start serving it. The result is
    # kept under the module-level name `demo`.
    demo = build_demo()
    demo.launch(
        share=True,
    )