import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from grounding_qwen import GroundQwenForCausalLM
from constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from conversation import conv_templates, SeparatorStyle
from utils import disable_torch_init
from mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import math
import pickle
from decord import VideoReader
import numpy as np

from transformers import StoppingCriteria, StoppingCriteriaList, AutoTokenizer
from petrel_client.client import Client
client = Client('~/petreloss.conf')


def get_seq_frames(total_num_frames, desired_num_frames):
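    """Sample `desired_num_frames` indices by splitting the frame range into
    equal segments and taking the midpoint of each segment."""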
    seg_size = float(total_num_frames - 1) / desired_num_frames
    seq = []
    for i in range(desired_num_frames):
        start = int(np.round(seg_size * i))
        end = int(np.round(seg_size * (i + 1)))
        seq.append((start + end) // 2)

    return seq


def get_seq_time(vr, frame_idx, num_clip):
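    """Split `frame_idx` into `num_clip` equal clips and return a flat array of
    length 2 * num_clip: the first half holds each clip's start time in seconds,
    the second half each clip's end time."""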
    frm_per_clip = len(frame_idx) // num_clip
    key_frame = [[frame_idx[i*frm_per_clip], frame_idx[i*frm_per_clip+frm_per_clip-1]] for i in range(num_clip)]
    time = vr.get_frame_timestamp(key_frame)
    return np.hstack([time[:, 0, 0], time[:, 1, 1]])


def calculate_diff(scene_sep, start_frame):
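    """Return the frame gaps between consecutive scene separators, with the gap
    from `start_frame` to the first separator prepended."""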
    diff = [scene_sep[0]-start_frame]
    for i in range(len(scene_sep)-1):
        diff.append(scene_sep[i+1]-scene_sep[i])
    return diff


def load_video(vis_path, scene_sep, num_frm=16, max_clip=4):
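    """Split the video into clips (uniformly when `scene_sep` is empty, otherwise
    along the scene separators, padding or pruning them so the clip count stays
    between 5 and `max_clip`), sample `num_frm` frames per clip resized to
    224x224, and return (PIL frames, clip start/end times, num_clip)."""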
    block_size = 1
    vr = VideoReader(vis_path)
    total_frame_num = len(vr)
    fps = vr.get_avg_fps()
    total_time = total_frame_num / fps

    if len(scene_sep) == 0:
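        # No scene boundaries: derive a clip count from the video duration and
        # sample frames uniformly over the whole video.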
        num_clip = total_time / num_frm
        num_clip = int(block_size*np.round(num_clip/block_size)) if num_clip > block_size else int(np.round(num_clip))
        num_clip = max(num_clip, 5)
        num_clip = min(num_clip, max_clip)
        total_num_frm = num_frm * num_clip
        start_frame = 0
        frame_idx = get_seq_frames(total_frame_num, total_num_frm)
    else:
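        # Scene boundaries given: split the largest gap to pad up to 5 clips, or
        # drop the separators with the smallest gaps to stay within max_clip.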
        num_clip = max(len(scene_sep), 5)
        num_clip = min(num_clip, max_clip)
        total_num_frm = num_frm * num_clip
        start_frame = 0
        frame_idx = []
        if len(scene_sep) < 5:
            diff = calculate_diff(scene_sep, start_frame)
            new_sep = max(diff) / (5-len(scene_sep)+1)
            max_idx = np.argmax(diff)
            if max_idx == 0:
                scene_sep = [int(start_frame+new_sep*i) for i in range(1, 5-len(scene_sep)+1)] + scene_sep
            else:
                scene_sep = scene_sep[:max_idx]+[int(scene_sep[max_idx-1]+i*new_sep) for i in range(1, 5-len(scene_sep)+1)] + scene_sep[max_idx:]
        elif len(scene_sep) > max_clip:
            diff = calculate_diff(scene_sep, start_frame)
            min_idx = np.argsort(diff[:-1])[:len(scene_sep)-max_clip] ##minimum diff to remove
            for i in np.sort(min_idx)[::-1]:
                del scene_sep[i]

        start_ = start_frame
        for end_frame in scene_sep:
            idx_list = np.linspace(start_, end_frame, num=num_frm, endpoint=False)
            frame_idx.extend([int(id) for id in idx_list])
            start_ = end_frame

    # Shared by both branches: clip time spans and the sampled frames.
    time_idx = get_seq_time(vr, frame_idx, num_clip)
    img_array = vr.get_batch(frame_idx).asnumpy()  # (num_clip * num_frm, H, W, 3)

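    # Resize frames to 224x224 if needed and convert each sampled frame to PIL.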
    a, H, W, _ = img_array.shape
    h, w = 224, 224
    if img_array.shape[-3] != h or img_array.shape[-2] != w:
        img_array = torch.from_numpy(img_array).permute(0, 3, 1, 2).float()
        img_array = torch.nn.functional.interpolate(img_array, size=(h, w))
        img_array = img_array.permute(0, 2, 3, 1).to(torch.uint8).numpy()
    img_array = img_array.reshape((1, total_num_frm, img_array.shape[-3], img_array.shape[-2], img_array.shape[-1]))

    clip_imgs = []
    for j in range(total_num_frm):
        clip_imgs.append(Image.fromarray(img_array[0, j]))

    return clip_imgs, time_idx, num_clip


def preprocess_time(time, num_clip, tokenizer):
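    """Build and tokenize one prompt per clip of the form "This contains a clip
    sampled in <start> to <end> seconds" followed by the image token."""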
    time = time.reshape(2, num_clip)
    seq = []

    block_size = 1
    for i in range(num_clip):
        start, end = time[:, i]
        start = int(np.round(start))
        end = int(np.round(end))
        if (i+1) % block_size == 0:
            history_end = end
        sentence = 'This contains a clip sampled in %d to %d seconds' % (start, end) + DEFAULT_IMAGE_TOKEN
        sentence = tokenizer_image_token(sentence, tokenizer, return_tensors='pt')
        seq.append(sentence)
    return seq


def preprocess_question(questions, tokenizer):
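    """Tokenize each question with image-token-aware tokenization."""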
    seq = []
    for q in questions:
        sentence = tokenizer_image_token(q, tokenizer, return_tensors='pt')
        seq.append(sentence)
    
    return seq



def eval_dataset(args):
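    """Load the GroundQwen model, preprocess one test video with its questions,
    and print the clip-question similarity returned by forward_grounding."""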
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)

    device = 'cuda'
    kwargs = {"device_map": 'auto'}
    kwargs['torch_dtype'] = torch.float16

    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = GroundQwenForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    vision_tower.to(device=device, dtype=torch.float16)
    image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

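    # Single-video grounding test; the video and scene-separation paths below
    # are environment-specific.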
    video_path = '/mnt/hwfile/mllm/qianrui/test_video/air.mp4'
    question = ['How many pilots are shown in the video?', 'What airlines are shown in the video?', 'How is the decoration of the airport?']
    scene_sep = json.load(open('/mnt/hwfile/mllm/qianrui/test_video/air.json', 'r'))['scene_sep']
    frames, time_idx, num_clips = load_video(video_path, scene_sep, num_frm=16, max_clip=32)
    video = image_processor.preprocess(frames, return_tensors='pt')['pixel_values']
    video = video.view(num_clips, 16, *video.shape[1:])
    seqs = preprocess_time(time_idx, num_clips, tokenizer)
    seqs = torch.nn.utils.rnn.pad_sequence(
        seqs, 
        batch_first=True,
        padding_value=tokenizer.pad_token_id)
    compress_mask = seqs.ne(tokenizer.pad_token_id)
    question = preprocess_question(question, tokenizer)
    question = torch.nn.utils.rnn.pad_sequence(
        question, 
        batch_first=True,
        padding_value=tokenizer.pad_token_id)
    qs_mask = question.ne(tokenizer.pad_token_id)

    with torch.inference_mode():
        similarity = model.forward_grounding(
            input_ids=seqs.to(device='cuda', non_blocking=True),
            attention_mask=compress_mask.to(device='cuda', non_blocking=True),
            images=video.to(dtype=torch.float16, device='cuda', non_blocking=True),
            qs_ids=question.to(device='cuda', non_blocking=True),
            qs_mask=qs_mask.to(device='cuda', non_blocking=True))
    
    print(similarity.shape, similarity)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    args = parser.parse_args()

    eval_dataset(args)
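
# Example invocation (script name and checkpoint path are placeholders):
#   python demo_grounding.py --model-path /path/to/ground-qwen-checkpoint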