THUdyh committed on
Commit 117219c · verified · 1 Parent(s): 42d2938

Update README.md

Files changed (1)
  1. README.md +309 -3
README.md CHANGED
@@ -1,9 +1,315 @@
  ---
  license: apache-2.0
  ---

- Ola-7b

- Usage:
  1. Download the speech encoder at https://huggingface.co/THUdyh/Ola_speech_encoders.
- 2. Replace the path in config.json to local path of speech encoders.
  ---
  license: apache-2.0
+ base_model:
+ - Qwen/Qwen2.5-7B-Instruct
+ pipeline_tag: text-generation
+ language:
+ - en
+ - zh
  ---

+ # Ola-7B
+
+ ## Model Summary
+
+ The Ola-7B model is trained on text, image, video, and audio data. Built on the Qwen2.5 language model with a 32K-token context window, it takes image/video, text, and audio as input and outputs text or speech.
+
+ Ola offers an on-demand solution for seamlessly and efficiently processing visual inputs of arbitrary spatial size and temporal length.
+
+ - **Repository:** https://github.com/xxxxx
+ - **Languages:** English, Chinese
+ - **Paper:** https://arxiv.org/abs/2501.xxxx
+
+ ## Use
+
  1. Download the speech encoder at https://huggingface.co/THUdyh/Ola_speech_encoders.
+ 2. Replace the path in config.json with the local path of the speech encoders (see the sketch below).
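+
+ A minimal sketch of these two steps using `huggingface_hub` (assumed installed); the repo id for the main checkpoint and the exact config.json fields are assumptions, so check them against your local copy:
+
+ ```python
+ from huggingface_hub import snapshot_download
+
+ # Step 1: download this model and the speech encoders to local directories.
+ # The repo id "THUdyh/Ola-7b" for the main checkpoint is assumed, not taken from this card.
+ ola_dir = snapshot_download(repo_id="THUdyh/Ola-7b")
+ speech_dir = snapshot_download(repo_id="THUdyh/Ola_speech_encoders")
+
+ # Step 2: edit <ola_dir>/config.json by hand so that its speech-encoder path entries
+ # point at the files downloaded into speech_dir.
+ print("Model directory: ", ola_dir)
+ print("Speech encoders:", speech_dir)
+ ```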
+
+ We provide a simple generation pipeline for using our model. For more details, please refer to our [GitHub repo](xxxxxx).
+
+ ```python
+ import os
+ # Resolution and loading options read by the Ola code; set them before importing the ola modules.
+ os.environ['LOWRES_RESIZE'] = '384x32'
+ os.environ['HIGHRES_BASE'] = '0x32'
+ os.environ['VIDEO_RESIZE'] = "0x64"
+ os.environ['VIDEO_MAXRES'] = "480"
+ os.environ['VIDEO_MINRES'] = "288"
+ os.environ['MAXRES'] = '1536'
+ os.environ['MINRES'] = '0'
+ os.environ['REGIONAL_POOL'] = '2x'
+ os.environ['FORCE_NO_DOWNSAMPLE'] = '1'
+ os.environ['LOAD_VISION_EARLY'] = '1'
+ os.environ['SKIP_LOAD_VIT'] = '1'
+
+
+ import gradio as gr
+ import torch
+ import re
+ from decord import VideoReader, cpu
+ from PIL import Image
+ import numpy as np
+ import transformers
+ import moviepy.editor as mp
+ from typing import Dict, Optional, Sequence, List
+ import librosa
+ import whisper
+ from ola.conversation import conv_templates, SeparatorStyle
+ from ola.model.builder import load_pretrained_model
+ from ola.utils import disable_torch_init
+ from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token
+ from ola.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image_genli
+ from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
+
+ model_path = ""  # set this to the local path of the Ola-7B checkpoint
+ tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None)
+ model = model.to('cuda').eval()
+ model = model.bfloat16()
+
+ USE_SPEECH = False
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+
+ def load_audio(audio_file_name):
+     # Load 16 kHz mono audio and convert it into at most 25 Whisper log-mel chunks of 30 s each.
+     speech_wav, samplerate = librosa.load(audio_file_name, sr=16000)
+     if len(speech_wav.shape) > 1:
+         speech_wav = speech_wav[:, 0]
+     speech_wav = speech_wav.astype(np.float32)
+     CHUNK_LIM = 480000
+     SAMPLE_RATE = 16000
+     speechs = []
+     speech_wavs = []
+
+     if len(speech_wav) <= CHUNK_LIM:
+         speech = whisper.pad_or_trim(speech_wav)
+         speech_wav = whisper.pad_or_trim(speech_wav)
+         speechs.append(speech)
+         speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0))
+     else:
+         for i in range(0, len(speech_wav), CHUNK_LIM):
+             chunk = speech_wav[i : i + CHUNK_LIM]
+             if len(chunk) < CHUNK_LIM:
+                 chunk = whisper.pad_or_trim(chunk)
+             speechs.append(chunk)
+             speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0))
+     mels = []
+     for chunk in speechs:
+         chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0)
+         mels.append(chunk)
+
+     mels = torch.cat(mels, dim=0)
+     speech_wavs = torch.cat(speech_wavs, dim=0)
+     if mels.shape[0] > 25:
+         mels = mels[:25]
+         speech_wavs = speech_wavs[:25]
+
+     speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0])
+     speech_chunks = torch.LongTensor([mels.shape[0]])
+     return mels, speech_length, speech_chunks, speech_wavs
+
+ def extract_audio(videos_file_path):
+     # Extract the audio track of a video file.
+     my_clip = mp.VideoFileClip(videos_file_path)
+     return my_clip.audio
+
+ def ola_inference(multimodal, audio_path):
+     # Run Ola on an image or video (optionally with a spoken question) and return the text output.
+     visual, text = multimodal["files"][0], multimodal["text"]
+     if visual.endswith("image2.png"):
+         modality = "video"
+         visual = f"{cur_dir}/case/case1.mp4"
+     if visual.endswith(".mp4"):
+         modality = "video"
+     else:
+         modality = "image"
+
+     # If an audio file is given, use it as the spoken question; otherwise parse the audio track of the video.
+     if audio_path:
+         USE_SPEECH = True
+     elif modality == "video":
+         USE_SPEECH = True
+     else:
+         USE_SPEECH = False
+
+     speechs = []
+     speech_lengths = []
+     speech_wavs = []
+     speech_chunks = []
+     if modality == "video":
+         vr = VideoReader(visual, ctx=cpu(0))
+         total_frame_num = len(vr)
+         fps = round(vr.get_avg_fps())
+         uniform_sampled_frames = np.linspace(0, total_frame_num - 1, 64, dtype=int)
+         frame_idx = uniform_sampled_frames.tolist()
+         spare_frames = vr.get_batch(frame_idx).asnumpy()
+         video = [Image.fromarray(frame) for frame in spare_frames]
+     else:
+         image = [Image.open(visual)]
+         image_sizes = [image[0].size]
+
+     if USE_SPEECH and audio_path:
+         speech, speech_length, speech_chunk, speech_wav = load_audio(audio_path)
+         speechs.append(speech.bfloat16().to('cuda'))
+         speech_lengths.append(speech_length.to('cuda'))
+         speech_chunks.append(speech_chunk.to('cuda'))
+         speech_wavs.append(speech_wav.to('cuda'))
+         print('load audio')
+     elif USE_SPEECH and not audio_path:
+         # Parse the audio track of the video.
+         audio = extract_audio(visual)
+         audio.write_audiofile("./video_audio.wav")
+         video_audio_path = './video_audio.wav'
+         speech, speech_length, speech_chunk, speech_wav = load_audio(video_audio_path)
+         speechs.append(speech.bfloat16().to('cuda'))
+         speech_lengths.append(speech_length.to('cuda'))
+         speech_chunks.append(speech_chunk.to('cuda'))
+         speech_wavs.append(speech_wav.to('cuda'))
+     else:
+         # No speech input: feed silent placeholders.
+         speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
+         speech_lengths = [torch.LongTensor([3000]).to('cuda')]
+         speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
+         speech_chunks = [torch.LongTensor([1]).to('cuda')]
+
+     # Build the prompt with the image/speech placeholder tokens.
+     conv_mode = "qwen_1_5"
+     if text:
+         qs = text
+     else:
+         qs = ''
+     if USE_SPEECH and audio_path:
+         qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n'
+     elif USE_SPEECH:
+         qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs
+     else:
+         qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
+
+     conv = conv_templates[conv_mode].copy()
+     conv.append_message(conv.roles[0], qs)
+     conv.append_message(conv.roles[1], None)
+     prompt = conv.get_prompt()
+     if USE_SPEECH and audio_path:
+         input_ids = tokenizer_speech_question_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
+     elif USE_SPEECH:
+         input_ids = tokenizer_speech_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
+     else:
+         input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to('cuda')
+
+     # Preprocess the visual input at its native resolution.
+     if modality == "video":
+         video_processed = []
+         for idx, frame in enumerate(video):
+             image_processor.do_resize = False
+             image_processor.do_center_crop = False
+             frame = process_anyres_video(frame, image_processor)
+
+             if frame_idx is not None and idx in frame_idx:
+                 video_processed.append(frame.unsqueeze(0))
+             elif frame_idx is None:
+                 video_processed.append(frame.unsqueeze(0))
+
+         if frame_idx is None:
+             frame_idx = np.arange(0, len(video_processed), dtype=int).tolist()
+
+         video_processed = torch.cat(video_processed, dim=0).bfloat16().to("cuda")
+         video_processed = (video_processed, video_processed)
+
+         video_data = (video_processed, (384, 384), "video")
+     else:
+         image_processor.do_resize = False
+         image_processor.do_center_crop = False
+         image_tensor, image_highres_tensor = [], []
+         for visual in image:
+             image_tensor_, image_highres_tensor_ = process_anyres_highres_image_genli(visual, image_processor)
+             image_tensor.append(image_tensor_)
+             image_highres_tensor.append(image_highres_tensor_)
+         if all(x.shape == image_tensor[0].shape for x in image_tensor):
+             image_tensor = torch.stack(image_tensor, dim=0)
+         if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor):
+             image_highres_tensor = torch.stack(image_highres_tensor, dim=0)
+         if type(image_tensor) is list:
+             image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor]
+         else:
+             image_tensor = image_tensor.bfloat16().to("cuda")
+         if type(image_highres_tensor) is list:
+             image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor]
+         else:
+             image_highres_tensor = image_highres_tensor.bfloat16().to("cuda")
+
+     pad_token_ids = 151643
+
+     attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
+     stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+     keywords = [stop_str]
+     stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
+
+     gen_kwargs = {}
+
+     if "max_new_tokens" not in gen_kwargs:
+         gen_kwargs["max_new_tokens"] = 1024
+     if "temperature" not in gen_kwargs:
+         gen_kwargs["temperature"] = 0.2
+     if "top_p" not in gen_kwargs:
+         gen_kwargs["top_p"] = None
+     if "num_beams" not in gen_kwargs:
+         gen_kwargs["num_beams"] = 1
+
+     with torch.inference_mode():
+         if modality == "video":
+             output_ids = model.generate(
+                 inputs=input_ids,
+                 images=video_data[0][0],
+                 images_highres=video_data[0][1],
+                 modalities=video_data[2],
+                 speech=speechs,
+                 speech_lengths=speech_lengths,
+                 speech_chunks=speech_chunks,
+                 speech_wav=speech_wavs,
+                 attention_mask=attention_masks,
+                 use_cache=True,
+                 stopping_criteria=[stopping_criteria],
+                 do_sample=True if gen_kwargs["temperature"] > 0 else False,
+                 temperature=gen_kwargs["temperature"],
+                 top_p=gen_kwargs["top_p"],
+                 num_beams=gen_kwargs["num_beams"],
+                 max_new_tokens=gen_kwargs["max_new_tokens"],
+             )
+         else:
+             output_ids = model.generate(
+                 inputs=input_ids,
+                 images=image_tensor,
+                 images_highres=image_highres_tensor,
+                 image_sizes=image_sizes,
+                 modalities=['image'],
+                 speech=speechs,
+                 speech_lengths=speech_lengths,
+                 speech_chunks=speech_chunks,
+                 speech_wav=speech_wavs,
+                 attention_mask=attention_masks,
+                 use_cache=True,
+                 stopping_criteria=[stopping_criteria],
+                 do_sample=True if gen_kwargs["temperature"] > 0 else False,
+                 temperature=gen_kwargs["temperature"],
+                 top_p=gen_kwargs["top_p"],
+                 num_beams=gen_kwargs["num_beams"],
+                 max_new_tokens=gen_kwargs["max_new_tokens"],
+             )
+
+     outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
+     outputs = outputs.strip()
+     if outputs.endswith(stop_str):
+         outputs = outputs[:-len(stop_str)]
+     outputs = outputs.strip()
+     return outputs, None
+ ```
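+
+ For reference, a minimal call of the `ola_inference` function defined above might look like the following; the media paths are illustrative placeholders:
+
+ ```python
+ # Video input: with audio_path=None, the audio track of the video is parsed automatically.
+ response, _ = ola_inference({"files": ["./examples/demo.mp4"], "text": "Describe the video."}, audio_path=None)
+ print(response)
+
+ # Image input with the question recorded as speech in a separate audio file.
+ response, _ = ola_inference({"files": ["./examples/demo.png"], "text": ""}, audio_path="./examples/question.wav")
+ print(response)
+ ```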
+
+ ### Model Architecture
+
+ - **Architecture:** Pre-trained [Oryx-ViT](https://huggingface.co/THUdyh/Oryx-ViT) + Qwen2.5-7B
+ - **Data:** a mixture of more than 5M image/video/audio samples, trained in 3 stages
+ - **Precision:** BFloat16
+
+ #### Hardware & Software
+
+ - **Hardware:** 64 × NVIDIA A100 GPUs
+ - **Orchestration:** HuggingFace Trainer
+ - **Code:** PyTorch
+
+ ## Citation