Commit 595ff0d by Chat-UniVi (parent: e91dd0b)

Update README.md

Files changed (1): README.md (+265, -0)

---
license: llama2
---

## 😮 Highlights

### 💡 Unified visual representation for image and video
We employ **a set of dynamic visual tokens** to uniformly represent images and videos.
This representation lets the model use **a limited number of visual tokens** to simultaneously capture **the spatial details necessary for images** and **the comprehensive temporal relationships required for videos**.
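
To make the idea of a fixed budget of dynamic tokens concrete, here is a deliberately simplified sketch: it merges however many patch tokens an image or video produces into the same small set of tokens via plain k-means-style averaging. Chat-UniVi's actual merging is a parameter-free density-peaks clustering (DPC-kNN); every name, shape, and the clustering routine below are illustrative, not the repository's implementation.

```python
import torch

def merge_visual_tokens(tokens: torch.Tensor, num_clusters: int, iters: int = 10) -> torch.Tensor:
    """Merge [N, D] patch tokens into [num_clusters, D] dynamic tokens.

    Illustrative k-means-style merging only; Chat-UniVi itself uses a
    DPC-kNN clustering scheme.
    """
    N = tokens.shape[0]
    # Initialize centroids from evenly spaced tokens.
    centroids = tokens[torch.linspace(0, N - 1, num_clusters).long()].clone()
    for _ in range(iters):
        assign = torch.cdist(tokens, centroids).argmin(dim=1)  # nearest centroid per token
        for k in range(num_clusters):
            members = tokens[assign == k]
            if len(members) > 0:
                centroids[k] = members.mean(dim=0)  # merged "dynamic token"
    return centroids

# One image (256 patch tokens) and an 8-frame video (2048 patch tokens)
# both end up within the same budget of 64 dynamic tokens.
image_tokens = torch.randn(256, 1024)
video_tokens = torch.randn(8 * 256, 1024)
print(merge_visual_tokens(image_tokens, 64).shape)  # torch.Size([64, 1024])
print(merge_visual_tokens(video_tokens, 64).shape)  # torch.Size([64, 1024])
```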

### 🔥 Joint training strategy, making LLMs understand both images and videos
Chat-UniVi is trained on a mixed dataset containing both images and videos, so it can be applied directly to tasks involving either medium without any modification.
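
A minimal sketch of this unification, under our own naming (nothing here is the repository's API): treat every visual input as a clip of shape [T, 3, H, W], where an image is simply a one-frame clip. A mixed batch then needs no image/video special-casing.

```python
import torch

def to_clip(x: torch.Tensor) -> torch.Tensor:
    # An image [3, H, W] becomes a one-frame clip [1, 3, H, W];
    # a video [T, 3, H, W] passes through unchanged.
    return x.unsqueeze(0) if x.dim() == 3 else x

image = torch.randn(3, 224, 224)
video = torch.randn(16, 3, 224, 224)
mixed_batch = [to_clip(image), to_clip(video)]
print([clip.shape[0] for clip in mixed_batch])  # [1, 16]
```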

### 🤗 High performance, complementary learning with image and video
Extensive experimental results demonstrate that Chat-UniVi, as a unified model, consistently outperforms even existing methods exclusively designed for either images or videos.

### Inference for Video Understanding
```python
import torch
import os
from ChatUniVi.constants import *
from ChatUniVi.conversation import conv_templates, SeparatorStyle
from ChatUniVi.model.builder import load_pretrained_model
from ChatUniVi.utils import disable_torch_init
from ChatUniVi.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
from decord import VideoReader, cpu
import numpy as np


def _get_rawvideo_dec(video_path, image_processor, max_frames=MAX_IMAGE_LENGTH,
                      image_resolution=224, video_framerate=1, s=None, e=None):
    # Decode the video with decord and sample at most `max_frames` frames.
    video_mask = np.zeros(max_frames, dtype=np.int64)
    max_video_length = 0

    # T x 3 x H x W zero-padded buffer, used only as a fallback.
    video = np.zeros((max_frames, 3, image_resolution, image_resolution), dtype=np.float64)

    # Optional start/end times in seconds; swap them if given in the wrong order.
    if s is None:
        start_time, end_time = None, None
    else:
        start_time = int(s)
        end_time = int(e)
        start_time = start_time if start_time >= 0. else 0.
        end_time = end_time if end_time >= 0. else 0.
        if start_time > end_time:
            start_time, end_time = end_time, start_time
        elif start_time == end_time:
            end_time = start_time + 1

    if os.path.exists(video_path):
        vreader = VideoReader(video_path, ctx=cpu(0))
    else:
        raise FileNotFoundError(video_path)

    fps = vreader.get_avg_fps()
    f_start = 0 if start_time is None else int(start_time * fps)
    f_end = int(min(1000000000 if end_time is None else end_time * fps, len(vreader) - 1))
    num_frames = f_end - f_start + 1
    if num_frames > 0:
        # Sample at `video_framerate` fps, then subsample evenly if the
        # result still exceeds `max_frames`.
        sample_fps = int(video_framerate)
        t_stride = int(round(float(fps) / sample_fps))

        all_pos = list(range(f_start, f_end + 1, t_stride))
        if len(all_pos) > max_frames:
            sample_pos = [all_pos[_] for _ in np.linspace(0, len(all_pos) - 1, num=max_frames, dtype=int)]
        else:
            sample_pos = all_pos

        patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
        patch_images = torch.stack(
            [image_processor.preprocess(img, return_tensors='pt')['pixel_values'][0] for img in patch_images])
        slice_len = patch_images.shape[0]

        max_video_length = max(max_video_length, slice_len)
        if slice_len >= 1:
            video[:slice_len, ...] = patch_images

        # Mark the valid frames before returning (the original snippet only
        # set the mask on the failure path, so it was always all zeros here).
        video_mask[:max_video_length] = 1
        return patch_images, video_mask

    print("video path: {} error.".format(video_path))
    return torch.from_numpy(video), video_mask


if __name__ == '__main__':
    # Model parameters: the quoted ${...} placeholders are left for you to fill in.
    model_path = "${model_path}"
    video_path = "${video_path}"
    max_frames = MAX_IMAGE_LENGTH  # frame budget (the original used a ${max_frames} placeholder)

    # Input text
    qs = "Describe the video."

    # Sampling parameters
    conv_mode = "simple"
    temperature = 0.2
    top_p = None
    num_beams = 1

    disable_torch_init()
    model_path = os.path.expanduser(model_path)
    model_name = "ChatUniVi"
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    image_processor = vision_tower.image_processor

    if model.config.config["use_cluster"]:
        for n, m in model.named_modules():
            m.to(dtype=torch.bfloat16)  # nn.Module.to() casts parameters in place

    # Check that a video was given.
    if video_path is not None:
        video_frames, _ = _get_rawvideo_dec(video_path, image_processor, max_frames=max_frames)

        cur_prompt = qs
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN * MAX_IMAGE_LENGTH + '\n' + qs

        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=video_frames.half().cuda(),
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                output_scores=True,
                return_dict_in_generate=True,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        output_ids = output_ids.sequences
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        print(outputs)
```
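
The `${...}` placeholders above are yours to fill in. Purely as a hypothetical example (none of these paths or values come from the repository):

```python
# Hypothetical values for the placeholders; adjust to your environment.
model_path = "path/to/Chat-UniVi"   # local checkpoint directory or Hub repo id
video_path = "examples/video.mp4"   # any local video file readable by decord
max_frames = MAX_IMAGE_LENGTH       # the frame budget passed to _get_rawvideo_dec
```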

### Inference for Image Understanding
```python
import torch
import os
from ChatUniVi.constants import *
from ChatUniVi.conversation import conv_templates, SeparatorStyle
from ChatUniVi.model.builder import load_pretrained_model
from ChatUniVi.utils import disable_torch_init
from ChatUniVi.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image


if __name__ == '__main__':
    # Model parameters: the quoted ${...} placeholders are left for you to fill in.
    model_path = "${model_path}"
    image_path = "${image_path}"

    # Input text
    qs = "Describe the image."

    # Sampling parameters
    conv_mode = "simple"
    temperature = 0.2
    top_p = None
    num_beams = 1

    disable_torch_init()
    model_path = os.path.expanduser(model_path)
    model_name = "ChatUniVi"
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name)

    mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
    mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
    if mm_use_im_patch_token:
        tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
    if mm_use_im_start_end:
        tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model()
    image_processor = vision_tower.image_processor

    # Check that an image was given.
    if image_path is not None:
        cur_prompt = qs
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs

        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()

        image = Image.open(image_path)
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                num_beams=num_beams,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        # Here generate() returns a plain tensor (no return_dict_in_generate).
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        print(outputs)
```
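
Both scripts end with the same decode-and-trim logic: decode only the newly generated tokens, then strip the conversation stop string. If you reuse it elsewhere, it can be factored into a small helper (this refactoring is ours, not the repository's):

```python
def decode_new_tokens(tokenizer, input_ids, output_ids, stop_str):
    # Decode only the tokens generated after the prompt, and strip the
    # conversation separator that the stopping criterion leaves at the end.
    input_token_len = input_ids.shape[1]
    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0].strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)].strip()
    return outputs
```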