舟勤 commited on
Commit
45d16e9
1 Parent(s): c778599
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +14 -0
  2. app.py +192 -0
  3. ckpt/blip2_pretrained_flant5xxl.pth +3 -0
  4. ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth +3 -0
  5. ckpt/finetune-vicuna7b-v2.pth +3 -0
  6. ckpt/pretrain-billa7b-zh.pth +3 -0
  7. eval_configs/video_llama_eval.yaml +32 -0
  8. requirements.txt +11 -0
  9. video_llama/__init__.py +31 -0
  10. video_llama/app.py +192 -0
  11. video_llama/ckpt/blip2_pretrained_flant5xxl.pth +3 -0
  12. video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth +3 -0
  13. video_llama/ckpt/pretrain-billa7b-zh.pth +3 -0
  14. video_llama/common/__init__.py +0 -0
  15. video_llama/common/config.py +468 -0
  16. video_llama/common/dist_utils.py +137 -0
  17. video_llama/common/gradcam.py +24 -0
  18. video_llama/common/logger.py +195 -0
  19. video_llama/common/optims.py +119 -0
  20. video_llama/common/registry.py +329 -0
  21. video_llama/common/utils.py +424 -0
  22. video_llama/configs/datasets/cc_sbu/align.yaml +5 -0
  23. video_llama/configs/datasets/cc_sbu/defaults.yaml +5 -0
  24. video_llama/configs/datasets/instruct/llava_instruct.yaml +6 -0
  25. video_llama/configs/datasets/instruct/webvid_instruct.yaml +6 -0
  26. video_llama/configs/datasets/laion/defaults.yaml +5 -0
  27. video_llama/configs/datasets/webvid/defaults.yaml +6 -0
  28. video_llama/configs/default.yaml +5 -0
  29. video_llama/configs/models/minigpt4.yaml +33 -0
  30. video_llama/configs/models/video_llama.yaml +36 -0
  31. video_llama/conversation/__init__.py +0 -0
  32. video_llama/conversation/conversation_video.py +248 -0
  33. video_llama/datasets/__init__.py +0 -0
  34. video_llama/datasets/builders/__init__.py +77 -0
  35. video_llama/datasets/builders/base_dataset_builder.py +236 -0
  36. video_llama/datasets/builders/image_text_pair_builder.py +106 -0
  37. video_llama/datasets/builders/instruct_builder.py +78 -0
  38. video_llama/datasets/builders/video_caption_builder.py +34 -0
  39. video_llama/datasets/data_utils.py +196 -0
  40. video_llama/datasets/datasets/__init__.py +0 -0
  41. video_llama/datasets/datasets/base_dataset.py +68 -0
  42. video_llama/datasets/datasets/caption_datasets.py +85 -0
  43. video_llama/datasets/datasets/cc_sbu_dataset.py +49 -0
  44. video_llama/datasets/datasets/dataloader_utils.py +162 -0
  45. video_llama/datasets/datasets/laion_dataset.py +31 -0
  46. video_llama/datasets/datasets/llava_instruct_dataset.py +228 -0
  47. video_llama/datasets/datasets/video_instruct_dataset.py +253 -0
  48. video_llama/datasets/datasets/webvid_datasets.py +122 -0
  49. video_llama/models/Qformer.py +1217 -0
  50. video_llama/models/__init__.py +201 -0
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Video LLaMA
3
+ emoji: 🚀
4
+ colorFrom: purple
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.29.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py
3
+ """
4
+ import argparse
5
+ import os
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.backends.cudnn as cudnn
11
+ import gradio as gr
12
+
13
+ from video_llama.common.config import Config
14
+ from video_llama.common.dist_utils import get_rank
15
+ from video_llama.common.registry import registry
16
+ from video_llama.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle
17
+ import decord
18
+ decord.bridge.set_bridge('torch')
19
+
20
+ #%%
21
+ # imports modules for registration
22
+ from video_llama.datasets.builders import *
23
+ from video_llama.models import *
24
+ from video_llama.processors import *
25
+ from video_llama.runners import *
26
+ from video_llama.tasks import *
27
+
28
+ #%%
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser(description="Demo")
31
+ parser.add_argument("--cfg-path", default='eval_configs/video_llama_eval.yaml', help="path to configuration file.")
32
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
33
+ parser.add_argument(
34
+ "--options",
35
+ nargs="+",
36
+ help="override some settings in the used config, the key-value pair "
37
+ "in xxx=yyy format will be merged into config file (deprecate), "
38
+ "change to --cfg-options instead.",
39
+ )
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+
44
+ def setup_seeds(config):
45
+ seed = config.run_cfg.seed + get_rank()
46
+
47
+ random.seed(seed)
48
+ np.random.seed(seed)
49
+ torch.manual_seed(seed)
50
+
51
+ cudnn.benchmark = False
52
+ cudnn.deterministic = True
53
+
54
+
55
+ # ========================================
56
+ # Model Initialization
57
+ # ========================================
58
+
59
+ print('Initializing Chat')
60
+ args = parse_args()
61
+ cfg = Config(args)
62
+
63
+ model_config = cfg.model_cfg
64
+ model_config.device_8bit = args.gpu_id
65
+ model_cls = registry.get_model_class(model_config.arch)
66
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
67
+
68
+ vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
69
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
70
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
71
+ print('Initialization Finished')
72
+
73
+ # ========================================
74
+ # Gradio Setting
75
+ # ========================================
76
+
77
+ def gradio_reset(chat_state, img_list):
78
+ if chat_state is not None:
79
+ chat_state.messages = []
80
+ if img_list is not None:
81
+ img_list = []
82
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
83
+
84
+ def upload_imgorvideo(gr_video, gr_img, text_input, chat_state):
85
+ if gr_img is None and gr_video is None:
86
+ return None, None, None, gr.update(interactive=True), chat_state, None
87
+ elif gr_img is not None and gr_video is None:
88
+ print(gr_img)
89
+ chat_state = Conversation(
90
+ system= "You are able to understand the visual content that the user provides."
91
+ "Follow the instructions carefully and explain your answers in detail.",
92
+ roles=("Human", "Assistant"),
93
+ messages=[],
94
+ offset=0,
95
+ sep_style=SeparatorStyle.SINGLE,
96
+ sep="###",
97
+ )
98
+ img_list = []
99
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
100
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
101
+ elif gr_video is not None and gr_img is None:
102
+ print(gr_video)
103
+ chat_state = default_conversation.copy()
104
+ chat_state = Conversation(
105
+ system= "You are able to understand the visual content that the user provides."
106
+ "Follow the instructions carefully and explain your answers in detail.",
107
+ roles=("Human", "Assistant"),
108
+ messages=[],
109
+ offset=0,
110
+ sep_style=SeparatorStyle.SINGLE,
111
+ sep="###",
112
+ )
113
+ img_list = []
114
+ llm_message = chat.upload_video(gr_video, chat_state, img_list)
115
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
116
+ else:
117
+ # img_list = []
118
+ return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None
119
+
120
+ def gradio_ask(user_message, chatbot, chat_state):
121
+ if len(user_message) == 0:
122
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
123
+ chat.ask(user_message, chat_state)
124
+ chatbot = chatbot + [[user_message, None]]
125
+ return '', chatbot, chat_state
126
+
127
+
128
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
129
+ llm_message = chat.answer(conv=chat_state,
130
+ img_list=img_list,
131
+ num_beams=num_beams,
132
+ temperature=temperature,
133
+ max_new_tokens=300,
134
+ max_length=2000)[0]
135
+ chatbot[-1][1] = llm_message
136
+ print(chat_state.get_prompt())
137
+ print(chat_state)
138
+ return chatbot, chat_state, img_list
139
+
140
+ title = """<h1 align="center">Demo of Video-LLaMA</h1>"""
141
+ description = """<h3>This is the demo of Video-LLaMA. Upload your images/videos and start chatting!</h3>"""
142
+
143
+
144
+ #TODO show examples below
145
+
146
+ with gr.Blocks() as demo:
147
+ gr.Markdown(title)
148
+ gr.Markdown(description)
149
+
150
+ with gr.Row():
151
+ with gr.Column(scale=0.5):
152
+ video = gr.Video()
153
+ image = gr.Image(type="pil")
154
+
155
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
156
+ clear = gr.Button("Restart")
157
+
158
+ num_beams = gr.Slider(
159
+ minimum=1,
160
+ maximum=10,
161
+ value=1,
162
+ step=1,
163
+ interactive=True,
164
+ label="beam search numbers)",
165
+ )
166
+
167
+ temperature = gr.Slider(
168
+ minimum=0.1,
169
+ maximum=2.0,
170
+ value=1.0,
171
+ step=0.1,
172
+ interactive=True,
173
+ label="Temperature",
174
+ )
175
+
176
+ with gr.Column():
177
+ chat_state = gr.State()
178
+ img_list = gr.State()
179
+ chatbot = gr.Chatbot(label='Video-LLaMA')
180
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image/video first', interactive=False)
181
+
182
+
183
+ upload_button.click(upload_imgorvideo, [video, image, text_input, chat_state], [video, image, text_input, upload_button, chat_state, img_list])
184
+
185
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
186
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
187
+ )
188
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, image, text_input, upload_button, chat_state, img_list], queue=False)
189
+
190
+ demo.launch(share=False, enable_queue=False)
191
+
192
+ # %%
ckpt/blip2_pretrained_flant5xxl.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b3839ea6c617f315ead9bf4036bbb0f0cf6bf62695ecfc14968ea626af03a29
3
+ size 433481467
ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc4b32437c90df51bc3faa29deaa9b25ab77e1707ac79066f17ae3193ebe8bfc
3
+ size 1527692539
ckpt/finetune-vicuna7b-v2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0680ad8eb14c2a3273b7be71309ab6b06c9f426e87ad4675a903371fe0fa8162
3
+ size 265436777
ckpt/pretrain-billa7b-zh.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f50a51db3055e1be6461f6dec833fbbbba28650287d26c8787664c8ee31dcf0f
3
+ size 265435689
eval_configs/video_llama_eval.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: video_llama
3
+ model_type: pretrain_vicuna
4
+ freeze_vit: True
5
+ freeze_qformer: True
6
+ max_txt_len: 512
7
+ end_sym: "###"
8
+ low_resource: False
9
+
10
+
11
+ llama_model: "DAMO-NLP-SG/vicuna-7b"
12
+
13
+ fusion_head_layers: 2
14
+ max_frame_pos: 32
15
+ fusion_header_type: "seqTransf"
16
+
17
+ ckpt: 'ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth'
18
+ q_former_model: 'ckpt/blip2_pretrained_flant5xxl.pth'
19
+
20
+ datasets:
21
+ webvid:
22
+ vis_processor:
23
+ train:
24
+ name: "alpro_video_eval"
25
+ n_frms: 8
26
+ image_size: 224
27
+ text_processor:
28
+ train:
29
+ name: "blip_caption"
30
+
31
+ run:
32
+ task: video_text_pretrain
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.28.0
2
+ tqdm
3
+ decord
4
+ timm
5
+ einops
6
+ opencv_python
7
+ torchvision
8
+
9
+ salesforce-lavis
10
+ bitsandbytes
11
+ accelerate
video_llama/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from video_llama.common.registry import registry
14
+
15
+ from video_llama.datasets.builders import *
16
+ from video_llama.models import *
17
+ from video_llama.processors import *
18
+ from video_llama.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
video_llama/app.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py
3
+ """
4
+ import argparse
5
+ import os
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.backends.cudnn as cudnn
11
+ import gradio as gr
12
+
13
+ from video_llama.common.config import Config
14
+ from video_llama.common.dist_utils import get_rank
15
+ from video_llama.common.registry import registry
16
+ from video_llama.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle
17
+ import decord
18
+ decord.bridge.set_bridge('torch')
19
+
20
+ #%%
21
+ # imports modules for registration
22
+ from video_llama.datasets.builders import *
23
+ from video_llama.models import *
24
+ from video_llama.processors import *
25
+ from video_llama.runners import *
26
+ from video_llama.tasks import *
27
+
28
+ #%%
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser(description="Demo")
31
+ parser.add_argument("--cfg-path", default='eval_configs/video_llama_eval.yaml', help="path to configuration file.")
32
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
33
+ parser.add_argument(
34
+ "--options",
35
+ nargs="+",
36
+ help="override some settings in the used config, the key-value pair "
37
+ "in xxx=yyy format will be merged into config file (deprecate), "
38
+ "change to --cfg-options instead.",
39
+ )
40
+ args = parser.parse_args()
41
+ return args
42
+
43
+
44
+ def setup_seeds(config):
45
+ seed = config.run_cfg.seed + get_rank()
46
+
47
+ random.seed(seed)
48
+ np.random.seed(seed)
49
+ torch.manual_seed(seed)
50
+
51
+ cudnn.benchmark = False
52
+ cudnn.deterministic = True
53
+
54
+
55
+ # ========================================
56
+ # Model Initialization
57
+ # ========================================
58
+
59
+ print('Initializing Chat')
60
+ args = parse_args()
61
+ cfg = Config(args)
62
+
63
+ model_config = cfg.model_cfg
64
+ model_config.device_8bit = args.gpu_id
65
+ model_cls = registry.get_model_class(model_config.arch)
66
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
67
+
68
+ vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
69
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
70
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
71
+ print('Initialization Finished')
72
+
73
+ # ========================================
74
+ # Gradio Setting
75
+ # ========================================
76
+
77
+ def gradio_reset(chat_state, img_list):
78
+ if chat_state is not None:
79
+ chat_state.messages = []
80
+ if img_list is not None:
81
+ img_list = []
82
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
83
+
84
+ def upload_imgorvideo(gr_video, gr_img, text_input, chat_state):
85
+ if gr_img is None and gr_video is None:
86
+ return None, None, None, gr.update(interactive=True), chat_state, None
87
+ elif gr_img is not None and gr_video is None:
88
+ print(gr_img)
89
+ chat_state = Conversation(
90
+ system= "You are able to understand the visual content that the user provides."
91
+ "Follow the instructions carefully and explain your answers in detail.",
92
+ roles=("Human", "Assistant"),
93
+ messages=[],
94
+ offset=0,
95
+ sep_style=SeparatorStyle.SINGLE,
96
+ sep="###",
97
+ )
98
+ img_list = []
99
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
100
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
101
+ elif gr_video is not None and gr_img is None:
102
+ print(gr_video)
103
+ chat_state = default_conversation.copy()
104
+ chat_state = Conversation(
105
+ system= "You are able to understand the visual content that the user provides."
106
+ "Follow the instructions carefully and explain your answers in detail.",
107
+ roles=("Human", "Assistant"),
108
+ messages=[],
109
+ offset=0,
110
+ sep_style=SeparatorStyle.SINGLE,
111
+ sep="###",
112
+ )
113
+ img_list = []
114
+ llm_message = chat.upload_video(gr_video, chat_state, img_list)
115
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list
116
+ else:
117
+ # img_list = []
118
+ return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None
119
+
120
+ def gradio_ask(user_message, chatbot, chat_state):
121
+ if len(user_message) == 0:
122
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
123
+ chat.ask(user_message, chat_state)
124
+ chatbot = chatbot + [[user_message, None]]
125
+ return '', chatbot, chat_state
126
+
127
+
128
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
129
+ llm_message = chat.answer(conv=chat_state,
130
+ img_list=img_list,
131
+ num_beams=num_beams,
132
+ temperature=temperature,
133
+ max_new_tokens=300,
134
+ max_length=2000)[0]
135
+ chatbot[-1][1] = llm_message
136
+ print(chat_state.get_prompt())
137
+ print(chat_state)
138
+ return chatbot, chat_state, img_list
139
+
140
+ title = """<h1 align="center">Demo of Video-LLaMA</h1>"""
141
+ description = """<h3>This is the demo of Video-LLaMA. Upload your images/videos and start chatting!</h3>"""
142
+
143
+
144
+ #TODO show examples below
145
+
146
+ with gr.Blocks() as demo:
147
+ gr.Markdown(title)
148
+ gr.Markdown(description)
149
+
150
+ with gr.Row():
151
+ with gr.Column(scale=0.5):
152
+ video = gr.Video()
153
+ image = gr.Image(type="pil")
154
+
155
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
156
+ clear = gr.Button("Restart")
157
+
158
+ num_beams = gr.Slider(
159
+ minimum=1,
160
+ maximum=10,
161
+ value=1,
162
+ step=1,
163
+ interactive=True,
164
+ label="beam search numbers)",
165
+ )
166
+
167
+ temperature = gr.Slider(
168
+ minimum=0.1,
169
+ maximum=2.0,
170
+ value=1.0,
171
+ step=0.1,
172
+ interactive=True,
173
+ label="Temperature",
174
+ )
175
+
176
+ with gr.Column():
177
+ chat_state = gr.State()
178
+ img_list = gr.State()
179
+ chatbot = gr.Chatbot(label='Video-LLaMA')
180
+ text_input = gr.Textbox(label='User', placeholder='Please upload your image/video first', interactive=False)
181
+
182
+
183
+ upload_button.click(upload_imgorvideo, [video, image, text_input, chat_state], [video, image, text_input, upload_button, chat_state, img_list])
184
+
185
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
186
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
187
+ )
188
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, image, text_input, upload_button, chat_state, img_list], queue=False)
189
+
190
+ demo.launch(share=False, enable_queue=False)
191
+
192
+ # %%
video_llama/ckpt/blip2_pretrained_flant5xxl.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b3839ea6c617f315ead9bf4036bbb0f0cf6bf62695ecfc14968ea626af03a29
3
+ size 433481467
video_llama/ckpt/finetune-vicuna7b-v2-nofrozen_imageQ.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46af76d307c14d28c56534e4bf8654343e5512aa1285fc1c1fdb5728c418e7ca
3
+ size 623104000
video_llama/ckpt/pretrain-billa7b-zh.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f50a51db3055e1be6461f6dec833fbbbba28650287d26c8787664c8ee31dcf0f
3
+ size 265435689
video_llama/common/__init__.py ADDED
File without changes
video_llama/common/config.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from video_llama.common.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ model_config_path = model_cls.default_config_path(model_type=model_type)
72
+
73
+ model_config = OmegaConf.create()
74
+ # hierarchy override, customized config > default config
75
+ model_config = OmegaConf.merge(
76
+ model_config,
77
+ OmegaConf.load(model_config_path),
78
+ {"model": config["model"]},
79
+ )
80
+
81
+ return model_config
82
+
83
+ @staticmethod
84
+ def build_runner_config(config):
85
+ return {"run": config.run}
86
+
87
+ @staticmethod
88
+ def build_dataset_config(config):
89
+ datasets = config.get("datasets", None)
90
+ if datasets is None:
91
+ raise KeyError(
92
+ "Expecting 'datasets' as the root key for dataset configuration."
93
+ )
94
+
95
+ dataset_config = OmegaConf.create()
96
+
97
+ for dataset_name in datasets:
98
+ builder_cls = registry.get_builder_class(dataset_name)
99
+
100
+ dataset_config_type = datasets[dataset_name].get("type", "default")
101
+ dataset_config_path = builder_cls.default_config_path(
102
+ type=dataset_config_type
103
+ )
104
+
105
+ # hierarchy override, customized config > default config
106
+ dataset_config = OmegaConf.merge(
107
+ dataset_config,
108
+ OmegaConf.load(dataset_config_path),
109
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110
+ )
111
+
112
+ return dataset_config
113
+
114
+ def _convert_to_dot_list(self, opts):
115
+ if opts is None:
116
+ opts = []
117
+
118
+ if len(opts) == 0:
119
+ return opts
120
+
121
+ has_equal = opts[0].find("=") != -1
122
+
123
+ if has_equal:
124
+ return opts
125
+
126
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127
+
128
+ def get_config(self):
129
+ return self.config
130
+
131
+ @property
132
+ def run_cfg(self):
133
+ return self.config.run
134
+
135
+ @property
136
+ def datasets_cfg(self):
137
+ return self.config.datasets
138
+
139
+ @property
140
+ def model_cfg(self):
141
+ return self.config.model
142
+
143
+ def pretty_print(self):
144
+ logging.info("\n===== Running Parameters =====")
145
+ logging.info(self._convert_node_to_json(self.config.run))
146
+
147
+ logging.info("\n====== Dataset Attributes ======")
148
+ datasets = self.config.datasets
149
+
150
+ for dataset in datasets:
151
+ if dataset in self.config.datasets:
152
+ logging.info(f"\n======== {dataset} =======")
153
+ dataset_config = self.config.datasets[dataset]
154
+ logging.info(self._convert_node_to_json(dataset_config))
155
+ else:
156
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157
+
158
+ logging.info(f"\n====== Model Attributes ======")
159
+ logging.info(self._convert_node_to_json(self.config.model))
160
+
161
+ def _convert_node_to_json(self, node):
162
+ container = OmegaConf.to_container(node, resolve=True)
163
+ return json.dumps(container, indent=4, sort_keys=True)
164
+
165
+ def to_dict(self):
166
+ return OmegaConf.to_container(self.config)
167
+
168
+
169
+ def node_to_dict(node):
170
+ return OmegaConf.to_container(node)
171
+
172
+
173
+ class ConfigValidator:
174
+ """
175
+ This is a preliminary implementation to centralize and validate the configuration.
176
+ May be altered in the future.
177
+
178
+ A helper class to validate configurations from yaml file.
179
+
180
+ This serves the following purposes:
181
+ 1. Ensure all the options in the yaml are defined, raise error if not.
182
+ 2. when type mismatches are found, the validator will raise an error.
183
+ 3. a central place to store and display helpful messages for supported configurations.
184
+
185
+ """
186
+
187
+ class _Argument:
188
+ def __init__(self, name, choices=None, type=None, help=None):
189
+ self.name = name
190
+ self.val = None
191
+ self.choices = choices
192
+ self.type = type
193
+ self.help = help
194
+
195
+ def __str__(self):
196
+ s = f"{self.name}={self.val}"
197
+ if self.type is not None:
198
+ s += f", ({self.type})"
199
+ if self.choices is not None:
200
+ s += f", choices: {self.choices}"
201
+ if self.help is not None:
202
+ s += f", ({self.help})"
203
+ return s
204
+
205
+ def __init__(self, description):
206
+ self.description = description
207
+
208
+ self.arguments = dict()
209
+
210
+ self.parsed_args = None
211
+
212
+ def __getitem__(self, key):
213
+ assert self.parsed_args is not None, "No arguments parsed yet."
214
+
215
+ return self.parsed_args[key]
216
+
217
+ def __str__(self) -> str:
218
+ return self.format_help()
219
+
220
+ def add_argument(self, *args, **kwargs):
221
+ """
222
+ Assume the first argument is the name of the argument.
223
+ """
224
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
225
+
226
+ def validate(self, config=None):
227
+ """
228
+ Convert yaml config (dict-like) to list, required by argparse.
229
+ """
230
+ for k, v in config.items():
231
+ assert (
232
+ k in self.arguments
233
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
234
+
235
+ if self.arguments[k].type is not None:
236
+ try:
237
+ self.arguments[k].val = self.arguments[k].type(v)
238
+ except ValueError:
239
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240
+
241
+ if self.arguments[k].choices is not None:
242
+ assert (
243
+ v in self.arguments[k].choices
244
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
245
+
246
+ return config
247
+
248
+ def format_arguments(self):
249
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
250
+
251
+ def format_help(self):
252
+ # description + key-value pair string for each argument
253
+ help_msg = str(self.description)
254
+ return help_msg + ", available arguments: " + self.format_arguments()
255
+
256
+ def print_help(self):
257
+ # display help message
258
+ print(self.format_help())
259
+
260
+
261
+ def create_runner_config_validator():
262
+ validator = ConfigValidator(description="Runner configurations")
263
+
264
+ validator.add_argument(
265
+ "runner",
266
+ type=str,
267
+ choices=["runner_base", "runner_iter"],
268
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
269
+ runner runs based on iters. Default: runner_base""",
270
+ )
271
+ # add argumetns for training dataset ratios
272
+ validator.add_argument(
273
+ "train_dataset_ratios",
274
+ type=Dict[str, float],
275
+ help="""Ratios of training dataset. This is used in iteration-based runner.
276
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
277
+ Default: None""",
278
+ )
279
+ validator.add_argument(
280
+ "max_iters",
281
+ type=float,
282
+ help="Maximum number of iterations to run.",
283
+ )
284
+ validator.add_argument(
285
+ "max_epoch",
286
+ type=int,
287
+ help="Maximum number of epochs to run.",
288
+ )
289
+ # add arguments for iters_per_inner_epoch
290
+ validator.add_argument(
291
+ "iters_per_inner_epoch",
292
+ type=float,
293
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294
+ )
295
+ lr_scheds_choices = registry.list_lr_schedulers()
296
+ validator.add_argument(
297
+ "lr_sched",
298
+ type=str,
299
+ choices=lr_scheds_choices,
300
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301
+ )
302
+ task_choices = registry.list_tasks()
303
+ validator.add_argument(
304
+ "task",
305
+ type=str,
306
+ choices=task_choices,
307
+ help="Task to use, from {}".format(task_choices),
308
+ )
309
+ # add arguments for init_lr
310
+ validator.add_argument(
311
+ "init_lr",
312
+ type=float,
313
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314
+ )
315
+ # add arguments for min_lr
316
+ validator.add_argument(
317
+ "min_lr",
318
+ type=float,
319
+ help="Minimum learning rate (after decay).",
320
+ )
321
+ # add arguments for warmup_lr
322
+ validator.add_argument(
323
+ "warmup_lr",
324
+ type=float,
325
+ help="Starting learning rate for warmup.",
326
+ )
327
+ # add arguments for learning rate decay rate
328
+ validator.add_argument(
329
+ "lr_decay_rate",
330
+ type=float,
331
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332
+ )
333
+ # add arguments for weight decay
334
+ validator.add_argument(
335
+ "weight_decay",
336
+ type=float,
337
+ help="Weight decay rate.",
338
+ )
339
+ # add arguments for training batch size
340
+ validator.add_argument(
341
+ "batch_size_train",
342
+ type=int,
343
+ help="Training batch size.",
344
+ )
345
+ # add arguments for evaluation batch size
346
+ validator.add_argument(
347
+ "batch_size_eval",
348
+ type=int,
349
+ help="Evaluation batch size, including validation and testing.",
350
+ )
351
+ # add arguments for number of workers for data loading
352
+ validator.add_argument(
353
+ "num_workers",
354
+ help="Number of workers for data loading.",
355
+ )
356
+ # add arguments for warm up steps
357
+ validator.add_argument(
358
+ "warmup_steps",
359
+ type=int,
360
+ help="Number of warmup steps. Required if a warmup schedule is used.",
361
+ )
362
+ # add arguments for random seed
363
+ validator.add_argument(
364
+ "seed",
365
+ type=int,
366
+ help="Random seed.",
367
+ )
368
+ # add arguments for output directory
369
+ validator.add_argument(
370
+ "output_dir",
371
+ type=str,
372
+ help="Output directory to save checkpoints and logs.",
373
+ )
374
+ # add arguments for whether only use evaluation
375
+ validator.add_argument(
376
+ "evaluate",
377
+ help="Whether to only evaluate the model. If true, training will not be performed.",
378
+ )
379
+ # add arguments for splits used for training, e.g. ["train", "val"]
380
+ validator.add_argument(
381
+ "train_splits",
382
+ type=list,
383
+ help="Splits to use for training.",
384
+ )
385
+ # add arguments for splits used for validation, e.g. ["val"]
386
+ validator.add_argument(
387
+ "valid_splits",
388
+ type=list,
389
+ help="Splits to use for validation. If not provided, will skip the validation.",
390
+ )
391
+ # add arguments for splits used for testing, e.g. ["test"]
392
+ validator.add_argument(
393
+ "test_splits",
394
+ type=list,
395
+ help="Splits to use for testing. If not provided, will skip the testing.",
396
+ )
397
+ # add arguments for accumulating gradient for iterations
398
+ validator.add_argument(
399
+ "accum_grad_iters",
400
+ type=int,
401
+ help="Number of iterations to accumulate gradient for.",
402
+ )
403
+
404
+ # ====== distributed training ======
405
+ validator.add_argument(
406
+ "device",
407
+ type=str,
408
+ choices=["cpu", "cuda"],
409
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
410
+ )
411
+ validator.add_argument(
412
+ "world_size",
413
+ type=int,
414
+ help="Number of processes participating in the job.",
415
+ )
416
+ validator.add_argument("dist_url", type=str)
417
+ validator.add_argument("distributed", type=bool)
418
+ # add arguments to opt using distributed sampler during evaluation or not
419
+ validator.add_argument(
420
+ "use_dist_eval_sampler",
421
+ type=bool,
422
+ help="Whether to use distributed sampler during evaluation or not.",
423
+ )
424
+
425
+ # ====== task specific ======
426
+ # generation task specific arguments
427
+ # add arguments for maximal length of text output
428
+ validator.add_argument(
429
+ "max_len",
430
+ type=int,
431
+ help="Maximal length of text output.",
432
+ )
433
+ # add arguments for minimal length of text output
434
+ validator.add_argument(
435
+ "min_len",
436
+ type=int,
437
+ help="Minimal length of text output.",
438
+ )
439
+ # add arguments number of beams
440
+ validator.add_argument(
441
+ "num_beams",
442
+ type=int,
443
+ help="Number of beams used for beam search.",
444
+ )
445
+
446
+ # vqa task specific arguments
447
+ # add arguments for number of answer candidates
448
+ validator.add_argument(
449
+ "num_ans_candidates",
450
+ type=int,
451
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452
+ )
453
+ # add arguments for inference method
454
+ validator.add_argument(
455
+ "inference_method",
456
+ type=str,
457
+ choices=["genearte", "rank"],
458
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
459
+ )
460
+
461
+ # ====== model specific ======
462
+ validator.add_argument(
463
+ "k_test",
464
+ type=int,
465
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466
+ )
467
+
468
+ return validator
video_llama/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.distributed = False
68
+ return
69
+
70
+ args.distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ args.dist_backend = "nccl"
74
+ print(
75
+ "| distributed init (rank {}, world {}): {}".format(
76
+ args.rank, args.world_size, args.dist_url
77
+ ),
78
+ flush=True,
79
+ )
80
+ torch.distributed.init_process_group(
81
+ backend=args.dist_backend,
82
+ init_method=args.dist_url,
83
+ world_size=args.world_size,
84
+ rank=args.rank,
85
+ timeout=datetime.timedelta(
86
+ days=365
87
+ ), # allow auto-downloading and de-compressing
88
+ )
89
+ torch.distributed.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ def get_dist_info():
94
+ if torch.__version__ < "1.0":
95
+ initialized = dist._initialized
96
+ else:
97
+ initialized = dist.is_initialized()
98
+ if initialized:
99
+ rank = dist.get_rank()
100
+ world_size = dist.get_world_size()
101
+ else: # non-distributed training
102
+ rank = 0
103
+ world_size = 1
104
+ return rank, world_size
105
+
106
+
107
+ def main_process(func):
108
+ @functools.wraps(func)
109
+ def wrapper(*args, **kwargs):
110
+ rank, _ = get_dist_info()
111
+ if rank == 0:
112
+ return func(*args, **kwargs)
113
+
114
+ return wrapper
115
+
116
+
117
+ def download_cached_file(url, check_hash=True, progress=False):
118
+ """
119
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
120
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
121
+ """
122
+
123
+ def get_cached_file_path():
124
+ # a hack to sync the file path across processes
125
+ parts = torch.hub.urlparse(url)
126
+ filename = os.path.basename(parts.path)
127
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
128
+
129
+ return cached_file
130
+
131
+ if is_main_process():
132
+ timm_hub.download_cached_file(url, check_hash, progress)
133
+
134
+ if is_dist_avail_and_initialized():
135
+ dist.barrier()
136
+
137
+ return get_cached_file_path()
video_llama/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
video_llama/common/logger.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from video_llama.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
video_llama/common/optims.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from video_llama.common.registry import registry
11
+
12
+
13
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
14
+ class LinearWarmupStepLRScheduler:
15
+ def __init__(
16
+ self,
17
+ optimizer,
18
+ max_epoch,
19
+ min_lr,
20
+ init_lr,
21
+ decay_rate=1,
22
+ warmup_start_lr=-1,
23
+ warmup_steps=0,
24
+ **kwargs
25
+ ):
26
+ self.optimizer = optimizer
27
+
28
+ self.max_epoch = max_epoch
29
+ self.min_lr = min_lr
30
+
31
+ self.decay_rate = decay_rate
32
+
33
+ self.init_lr = init_lr
34
+ self.warmup_steps = warmup_steps
35
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
36
+
37
+ def step(self, cur_epoch, cur_step):
38
+ if cur_epoch == 0:
39
+ warmup_lr_schedule(
40
+ step=cur_step,
41
+ optimizer=self.optimizer,
42
+ max_step=self.warmup_steps,
43
+ init_lr=self.warmup_start_lr,
44
+ max_lr=self.init_lr,
45
+ )
46
+ else:
47
+ step_lr_schedule(
48
+ epoch=cur_epoch,
49
+ optimizer=self.optimizer,
50
+ init_lr=self.init_lr,
51
+ min_lr=self.min_lr,
52
+ decay_rate=self.decay_rate,
53
+ )
54
+
55
+
56
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
57
+ class LinearWarmupCosineLRScheduler:
58
+ def __init__(
59
+ self,
60
+ optimizer,
61
+ max_epoch,
62
+ iters_per_epoch,
63
+ min_lr,
64
+ init_lr,
65
+ warmup_steps=0,
66
+ warmup_start_lr=-1,
67
+ **kwargs
68
+ ):
69
+ self.optimizer = optimizer
70
+
71
+ self.max_epoch = max_epoch
72
+ self.iters_per_epoch = iters_per_epoch
73
+ self.min_lr = min_lr
74
+
75
+ self.init_lr = init_lr
76
+ self.warmup_steps = warmup_steps
77
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
78
+
79
+ def step(self, cur_epoch, cur_step):
80
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
81
+ if total_cur_step < self.warmup_steps:
82
+ warmup_lr_schedule(
83
+ step=cur_step,
84
+ optimizer=self.optimizer,
85
+ max_step=self.warmup_steps,
86
+ init_lr=self.warmup_start_lr,
87
+ max_lr=self.init_lr,
88
+ )
89
+ else:
90
+ cosine_lr_schedule(
91
+ epoch=total_cur_step,
92
+ optimizer=self.optimizer,
93
+ max_epoch=self.max_epoch * self.iters_per_epoch,
94
+ init_lr=self.init_lr,
95
+ min_lr=self.min_lr,
96
+ )
97
+
98
+
99
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
100
+ """Decay the learning rate"""
101
+ lr = (init_lr - min_lr) * 0.5 * (
102
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
103
+ ) + min_lr
104
+ for param_group in optimizer.param_groups:
105
+ param_group["lr"] = lr
106
+
107
+
108
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
109
+ """Warmup the learning rate"""
110
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
111
+ for param_group in optimizer.param_groups:
112
+ param_group["lr"] = lr
113
+
114
+
115
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
116
+ """Decay the learning rate"""
117
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
118
+ for param_group in optimizer.param_groups:
119
+ param_group["lr"] = lr
video_llama/common/registry.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+ @classmethod
22
+ def register_builder(cls, name):
23
+ r"""Register a dataset builder to registry with key 'name'
24
+
25
+ Args:
26
+ name: Key with which the builder will be registered.
27
+
28
+ Usage:
29
+
30
+ from video_llama.common.registry import registry
31
+ from video_llama.datasets.base_dataset_builder import BaseDatasetBuilder
32
+ """
33
+
34
+ def wrap(builder_cls):
35
+ from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
+
37
+ assert issubclass(
38
+ builder_cls, BaseDatasetBuilder
39
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
+ builder_cls
41
+ )
42
+ if name in cls.mapping["builder_name_mapping"]:
43
+ raise KeyError(
44
+ "Name '{}' already registered for {}.".format(
45
+ name, cls.mapping["builder_name_mapping"][name]
46
+ )
47
+ )
48
+ cls.mapping["builder_name_mapping"][name] = builder_cls
49
+ return builder_cls
50
+
51
+ return wrap
52
+
53
+ @classmethod
54
+ def register_task(cls, name):
55
+ r"""Register a task to registry with key 'name'
56
+
57
+ Args:
58
+ name: Key with which the task will be registered.
59
+
60
+ Usage:
61
+
62
+ from video_llama.common.registry import registry
63
+ """
64
+
65
+ def wrap(task_cls):
66
+ from video_llama.tasks.base_task import BaseTask
67
+
68
+ assert issubclass(
69
+ task_cls, BaseTask
70
+ ), "All tasks must inherit BaseTask class"
71
+ if name in cls.mapping["task_name_mapping"]:
72
+ raise KeyError(
73
+ "Name '{}' already registered for {}.".format(
74
+ name, cls.mapping["task_name_mapping"][name]
75
+ )
76
+ )
77
+ cls.mapping["task_name_mapping"][name] = task_cls
78
+ return task_cls
79
+
80
+ return wrap
81
+
82
+ @classmethod
83
+ def register_model(cls, name):
84
+ r"""Register a task to registry with key 'name'
85
+
86
+ Args:
87
+ name: Key with which the task will be registered.
88
+
89
+ Usage:
90
+
91
+ from video_llama.common.registry import registry
92
+ """
93
+
94
+ def wrap(model_cls):
95
+ from video_llama.models import BaseModel
96
+
97
+ assert issubclass(
98
+ model_cls, BaseModel
99
+ ), "All models must inherit BaseModel class"
100
+ if name in cls.mapping["model_name_mapping"]:
101
+ raise KeyError(
102
+ "Name '{}' already registered for {}.".format(
103
+ name, cls.mapping["model_name_mapping"][name]
104
+ )
105
+ )
106
+ cls.mapping["model_name_mapping"][name] = model_cls
107
+ return model_cls
108
+
109
+ return wrap
110
+
111
+ @classmethod
112
+ def register_processor(cls, name):
113
+ r"""Register a processor to registry with key 'name'
114
+
115
+ Args:
116
+ name: Key with which the task will be registered.
117
+
118
+ Usage:
119
+
120
+ from video_llama.common.registry import registry
121
+ """
122
+
123
+ def wrap(processor_cls):
124
+ from video_llama.processors import BaseProcessor
125
+
126
+ assert issubclass(
127
+ processor_cls, BaseProcessor
128
+ ), "All processors must inherit BaseProcessor class"
129
+ if name in cls.mapping["processor_name_mapping"]:
130
+ raise KeyError(
131
+ "Name '{}' already registered for {}.".format(
132
+ name, cls.mapping["processor_name_mapping"][name]
133
+ )
134
+ )
135
+ cls.mapping["processor_name_mapping"][name] = processor_cls
136
+ return processor_cls
137
+
138
+ return wrap
139
+
140
+ @classmethod
141
+ def register_lr_scheduler(cls, name):
142
+ r"""Register a model to registry with key 'name'
143
+
144
+ Args:
145
+ name: Key with which the task will be registered.
146
+
147
+ Usage:
148
+
149
+ from video_llama.common.registry import registry
150
+ """
151
+
152
+ def wrap(lr_sched_cls):
153
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
154
+ raise KeyError(
155
+ "Name '{}' already registered for {}.".format(
156
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
157
+ )
158
+ )
159
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160
+ return lr_sched_cls
161
+
162
+ return wrap
163
+
164
+ @classmethod
165
+ def register_runner(cls, name):
166
+ r"""Register a model to registry with key 'name'
167
+
168
+ Args:
169
+ name: Key with which the task will be registered.
170
+
171
+ Usage:
172
+
173
+ from video_llama.common.registry import registry
174
+ """
175
+
176
+ def wrap(runner_cls):
177
+ if name in cls.mapping["runner_name_mapping"]:
178
+ raise KeyError(
179
+ "Name '{}' already registered for {}.".format(
180
+ name, cls.mapping["runner_name_mapping"][name]
181
+ )
182
+ )
183
+ cls.mapping["runner_name_mapping"][name] = runner_cls
184
+ return runner_cls
185
+
186
+ return wrap
187
+
188
+ @classmethod
189
+ def register_path(cls, name, path):
190
+ r"""Register a path to registry with key 'name'
191
+
192
+ Args:
193
+ name: Key with which the path will be registered.
194
+
195
+ Usage:
196
+
197
+ from video_llama.common.registry import registry
198
+ """
199
+ assert isinstance(path, str), "All path must be str."
200
+ if name in cls.mapping["paths"]:
201
+ raise KeyError("Name '{}' already registered.".format(name))
202
+ cls.mapping["paths"][name] = path
203
+
204
+ @classmethod
205
+ def register(cls, name, obj):
206
+ r"""Register an item to registry with key 'name'
207
+
208
+ Args:
209
+ name: Key with which the item will be registered.
210
+
211
+ Usage::
212
+
213
+ from video_llama.common.registry import registry
214
+
215
+ registry.register("config", {})
216
+ """
217
+ path = name.split(".")
218
+ current = cls.mapping["state"]
219
+
220
+ for part in path[:-1]:
221
+ if part not in current:
222
+ current[part] = {}
223
+ current = current[part]
224
+
225
+ current[path[-1]] = obj
226
+
227
+ # @classmethod
228
+ # def get_trainer_class(cls, name):
229
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
230
+
231
+ @classmethod
232
+ def get_builder_class(cls, name):
233
+ return cls.mapping["builder_name_mapping"].get(name, None)
234
+
235
+ @classmethod
236
+ def get_model_class(cls, name):
237
+ return cls.mapping["model_name_mapping"].get(name, None)
238
+
239
+ @classmethod
240
+ def get_task_class(cls, name):
241
+ return cls.mapping["task_name_mapping"].get(name, None)
242
+
243
+ @classmethod
244
+ def get_processor_class(cls, name):
245
+ return cls.mapping["processor_name_mapping"].get(name, None)
246
+
247
+ @classmethod
248
+ def get_lr_scheduler_class(cls, name):
249
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250
+
251
+ @classmethod
252
+ def get_runner_class(cls, name):
253
+ return cls.mapping["runner_name_mapping"].get(name, None)
254
+
255
+ @classmethod
256
+ def list_runners(cls):
257
+ return sorted(cls.mapping["runner_name_mapping"].keys())
258
+
259
+ @classmethod
260
+ def list_models(cls):
261
+ return sorted(cls.mapping["model_name_mapping"].keys())
262
+
263
+ @classmethod
264
+ def list_tasks(cls):
265
+ return sorted(cls.mapping["task_name_mapping"].keys())
266
+
267
+ @classmethod
268
+ def list_processors(cls):
269
+ return sorted(cls.mapping["processor_name_mapping"].keys())
270
+
271
+ @classmethod
272
+ def list_lr_schedulers(cls):
273
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274
+
275
+ @classmethod
276
+ def list_datasets(cls):
277
+ return sorted(cls.mapping["builder_name_mapping"].keys())
278
+
279
+ @classmethod
280
+ def get_path(cls, name):
281
+ return cls.mapping["paths"].get(name, None)
282
+
283
+ @classmethod
284
+ def get(cls, name, default=None, no_warning=False):
285
+ r"""Get an item from registry with key 'name'
286
+
287
+ Args:
288
+ name (string): Key whose value needs to be retrieved.
289
+ default: If passed and key is not in registry, default value will
290
+ be returned with a warning. Default: None
291
+ no_warning (bool): If passed as True, warning when key doesn't exist
292
+ will not be generated. Useful for MMF's
293
+ internal operations. Default: False
294
+ """
295
+ original_name = name
296
+ name = name.split(".")
297
+ value = cls.mapping["state"]
298
+ for subname in name:
299
+ value = value.get(subname, default)
300
+ if value is default:
301
+ break
302
+
303
+ if (
304
+ "writer" in cls.mapping["state"]
305
+ and value == default
306
+ and no_warning is False
307
+ ):
308
+ cls.mapping["state"]["writer"].warning(
309
+ "Key {} is not present in registry, returning default value "
310
+ "of {}".format(original_name, default)
311
+ )
312
+ return value
313
+
314
+ @classmethod
315
+ def unregister(cls, name):
316
+ r"""Remove an item from registry with key 'name'
317
+
318
+ Args:
319
+ name: Key which needs to be removed.
320
+ Usage::
321
+
322
+ from mmf.common.registry import registry
323
+
324
+ config = registry.unregister("config")
325
+ """
326
+ return cls.mapping["state"].pop(name, None)
327
+
328
+
329
+ registry = Registry()
video_llama/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from video_llama.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
video_llama/configs/datasets/cc_sbu/align.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ cc_sbu_align:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/cc_sbu_align_dataset
video_llama/configs/datasets/cc_sbu/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ cc_sbu:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/cc_sbu_dataset/{00000..00001}.tar
video_llama/configs/datasets/instruct/llava_instruct.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ llava_instruct:
3
+ data_type: image
4
+ build_info:
5
+ anno_dir: /path/llava_instruct_150k.json
6
+ videos_dir: /path/train2014/train2014/
video_llama/configs/datasets/instruct/webvid_instruct.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ webvid_instruct:
3
+ data_type: image
4
+ build_info:
5
+ anno_dir: /path/webvid_align/videochat_instruct_11k.json
6
+ videos_dir: /path/webvid_align/videos/
video_llama/configs/datasets/laion/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ laion:
3
+ data_type: images
4
+ build_info:
5
+ storage: path/laion/laion_dataset/{00000..00001}.tar
video_llama/configs/datasets/webvid/defaults.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ webvid:
3
+ data_type: video
4
+ build_info:
5
+ anno_dir: path/webvid/webvid_tain_data/annotations/
6
+ videos_dir: path//webvid/webvid_tain_data/videos/
video_llama/configs/default.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ env:
2
+ # For default users
3
+ # cache_root: "cache"
4
+ # For internal use with persistent storage
5
+ cache_root: "/export/home/.cache/minigpt4"
video_llama/configs/models/minigpt4.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: mini_gpt4
3
+
4
+ # vit encoder
5
+ image_size: 224
6
+ drop_path_rate: 0
7
+ use_grad_checkpoint: False
8
+ vit_precision: "fp16"
9
+ freeze_vit: True
10
+ freeze_qformer: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # Vicuna
16
+ llama_model: "ckpt/vicuna-13b/"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "blip2_image_train"
25
+ image_size: 224
26
+ eval:
27
+ name: "blip2_image_eval"
28
+ image_size: 224
29
+ text_processor:
30
+ train:
31
+ name: "blip_caption"
32
+ eval:
33
+ name: "blip_caption"
video_llama/configs/models/video_llama.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: video_llama
3
+
4
+ # vit encoder
5
+ image_size: 224
6
+ drop_path_rate: 0
7
+ use_grad_checkpoint: False
8
+ vit_precision: "fp16"
9
+ freeze_vit: True
10
+ freeze_qformer: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # Vicuna
16
+ llama_model: "ckpt/vicuna-7b/"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "alpro_video_train"
25
+ image_size: 224
26
+ n_frms: 8
27
+ eval:
28
+ name: "alpro_video_eval"
29
+ image_size: 224
30
+ n_frms: 8
31
+ text_processor:
32
+ train:
33
+ name: "blip_caption"
34
+ eval:
35
+ name: "blip_caption"
36
+
video_llama/conversation/__init__.py ADDED
File without changes
video_llama/conversation/conversation_video.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt template of Video-LLaMA.
3
+ Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
4
+ """
5
+ import argparse
6
+ import time
7
+ from PIL import Image
8
+
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
11
+ from transformers import StoppingCriteria, StoppingCriteriaList
12
+
13
+ import dataclasses
14
+ from enum import auto, Enum
15
+ from typing import List, Tuple, Any
16
+ import os
17
+ from video_llama.common.registry import registry
18
+ from video_llama.processors.video_processor import ToTHWC,ToUint8,load_video
19
+ from video_llama.processors import Blip2ImageEvalProcessor
20
+ class SeparatorStyle(Enum):
21
+ """Different separator style."""
22
+ SINGLE = auto()
23
+ TWO = auto()
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class Conversation:
28
+ """A class that keeps all conversation history."""
29
+ system: str
30
+ roles: List[str]
31
+ messages: List[List[str]]
32
+ offset: int
33
+ # system_img: List[Image.Image] = []
34
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
35
+ sep: str = "###"
36
+ sep2: str = None
37
+
38
+ skip_next: bool = False
39
+ conv_id: Any = None
40
+
41
+ def get_prompt(self):
42
+ if self.sep_style == SeparatorStyle.SINGLE:
43
+ ret = self.system + self.sep
44
+ for role, message in self.messages:
45
+ if message:
46
+ ret += role + ": " + message + self.sep
47
+ else:
48
+ ret += role + ":"
49
+ return ret
50
+ elif self.sep_style == SeparatorStyle.TWO:
51
+ seps = [self.sep, self.sep2]
52
+ ret = self.system + seps[0]
53
+ for i, (role, message) in enumerate(self.messages):
54
+ if message:
55
+ ret += role + ": " + message + seps[i % 2]
56
+ else:
57
+ ret += role + ":"
58
+ return ret
59
+ else:
60
+ raise ValueError(f"Invalid style: {self.sep_style}")
61
+
62
+ def append_message(self, role, message):
63
+ self.messages.append([role, message])
64
+
65
+ def to_gradio_chatbot(self):
66
+ ret = []
67
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
68
+ if i % 2 == 0:
69
+ ret.append([msg, None])
70
+ else:
71
+ ret[-1][-1] = msg
72
+ return ret
73
+
74
+ def copy(self):
75
+ return Conversation(
76
+ system=self.system,
77
+ # system_img=self.system_img,
78
+ roles=self.roles,
79
+ messages=[[x, y] for x, y in self.messages],
80
+ offset=self.offset,
81
+ sep_style=self.sep_style,
82
+ sep=self.sep,
83
+ sep2=self.sep2,
84
+ conv_id=self.conv_id)
85
+
86
+ def dict(self):
87
+ return {
88
+ "system": self.system,
89
+ # "system_img": self.system_img,
90
+ "roles": self.roles,
91
+ "messages": self.messages,
92
+ "offset": self.offset,
93
+ "sep": self.sep,
94
+ "sep2": self.sep2,
95
+ "conv_id": self.conv_id,
96
+ }
97
+
98
+
99
+ class StoppingCriteriaSub(StoppingCriteria):
100
+
101
+ def __init__(self, stops=[], encounters=1):
102
+ super().__init__()
103
+ self.stops = stops
104
+
105
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
106
+ for stop in self.stops:
107
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
108
+ return True
109
+
110
+ return False
111
+
112
+
113
+ CONV_VISION = Conversation(
114
+ system="Give the following image: <Img>ImageContent</Img>. "
115
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
116
+ roles=("Human", "Assistant"),
117
+ messages=[],
118
+ offset=0,
119
+ sep_style=SeparatorStyle.SINGLE,
120
+ sep="###",
121
+ )
122
+
123
+ default_conversation = Conversation(
124
+ system="",
125
+ roles=("Human", "Assistant"),
126
+ messages=[],
127
+ offset=0,
128
+ sep_style=SeparatorStyle.SINGLE,
129
+ sep="###",
130
+ )
131
+
132
+ class Chat:
133
+ def __init__(self, model, vis_processor, device='cuda:0'):
134
+ self.device = device
135
+ self.model = model
136
+ self.vis_processor = vis_processor
137
+ self.image_vis_processor = Blip2ImageEvalProcessor()
138
+ stop_words_ids = [torch.tensor([835]).to(self.device),
139
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
140
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
141
+
142
+ def ask(self, text, conv):
143
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
144
+ and ('</Video>' in conv.messages[-1][1] or '</Image>' in conv.messages[-1][1]): # last message is image.
145
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
146
+ else:
147
+ conv.append_message(conv.roles[0], text)
148
+
149
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
150
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
151
+ conv.append_message(conv.roles[1], None)
152
+ embs = self.get_context_emb(conv, img_list)
153
+
154
+ current_max_len = embs.shape[1] + max_new_tokens
155
+ if current_max_len - max_length > 0:
156
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
157
+ 'The model will not see the contexts outside the range.')
158
+ begin_idx = max(0, current_max_len - max_length)
159
+
160
+ embs = embs[:, begin_idx:]
161
+
162
+ outputs = self.model.llama_model.generate(
163
+ inputs_embeds=embs,
164
+ max_new_tokens=max_new_tokens,
165
+ stopping_criteria=self.stopping_criteria,
166
+ num_beams=num_beams,
167
+ do_sample=True,
168
+ min_length=min_length,
169
+ top_p=top_p,
170
+ repetition_penalty=repetition_penalty,
171
+ length_penalty=length_penalty,
172
+ temperature=temperature,
173
+ )
174
+ output_token = outputs[0]
175
+ if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
176
+ output_token = output_token[1:]
177
+ if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
178
+ output_token = output_token[1:]
179
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
180
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
181
+ output_text = output_text.split('Assistant:')[-1].strip()
182
+ conv.messages[-1][1] = output_text
183
+ return output_text, output_token.cpu().numpy()
184
+
185
+ def upload_video(self, video, conv, img_list):
186
+
187
+ msg = ""
188
+ if isinstance(video, str): # is a video path
189
+ ext = os.path.splitext(video)[-1].lower()
190
+ print(video)
191
+ # image = self.vis_processor(image).unsqueeze(0).to(self.device)
192
+ video, msg = load_video(
193
+ video_path=video,
194
+ n_frms=8,
195
+ height=224,
196
+ width=224,
197
+ sampling ="uniform", return_msg = True
198
+ )
199
+ video = self.vis_processor.transform(video)
200
+ video = video.unsqueeze(0).to(self.device)
201
+ # print(image)
202
+ else:
203
+ raise NotImplementedError
204
+
205
+ image_emb, _ = self.model.encode_img(video)
206
+ img_list.append(image_emb)
207
+ conv.append_message(conv.roles[0], "<Video><ImageHere></Video> "+ msg)
208
+ return "Received."
209
+
210
+ def upload_img(self, image, conv, img_list):
211
+
212
+ msg = ""
213
+ if isinstance(image, str): # is a image path
214
+ raw_image = Image.open(image).convert('RGB') # 增加一个时间维度
215
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
216
+ elif isinstance(image, Image.Image):
217
+ raw_image = image
218
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
219
+ elif isinstance(image, torch.Tensor):
220
+ if len(image.shape) == 3:
221
+ image = image.unsqueeze(0)
222
+ image = image.to(self.device)
223
+ else:
224
+ raise NotImplementedError
225
+
226
+ image_emb, _ = self.model.encode_img(image)
227
+ img_list.append(image_emb)
228
+ # Todo msg=""
229
+ conv.append_message(conv.roles[0], "<Image><ImageHere></Image> "+ msg)
230
+
231
+ return "Received."
232
+
233
+ def get_context_emb(self, conv, img_list):
234
+ prompt = conv.get_prompt()
235
+ prompt_segs = prompt.split('<ImageHere>')
236
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
237
+ seg_tokens = [
238
+ self.model.llama_tokenizer(
239
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
240
+ # only add bos to the first seg
241
+ for i, seg in enumerate(prompt_segs)
242
+ ]
243
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
244
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
245
+ mixed_embs = torch.cat(mixed_embs, dim=1)
246
+ return mixed_embs
247
+
248
+
video_llama/datasets/__init__.py ADDED
File without changes
video_llama/datasets/builders/__init__.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from video_llama.datasets.builders.base_dataset_builder import load_dataset_config
9
+ from video_llama.datasets.builders.image_text_pair_builder import (
10
+ CCSBUBuilder,
11
+ LaionBuilder,
12
+ CCSBUAlignBuilder
13
+ )
14
+ from video_llama.datasets.builders.video_caption_builder import WebvidBuilder
15
+ from video_llama.common.registry import registry
16
+ from video_llama.datasets.builders.instruct_builder import WebvidInstruct_Builder,LlavaInstruct_Builder
17
+ __all__ = [
18
+ "CCSBUBuilder",
19
+ "LaionBuilder",
20
+ "CCSBUAlignBuilder",
21
+ "WebvidBuilder",
22
+ "LlavaInstruct_Builder",
23
+ "WebvidInstruct_Builder"
24
+
25
+ ]
26
+
27
+
28
+ def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
29
+ """
30
+ Example
31
+
32
+ >>> dataset = load_dataset("coco_caption", cfg=None)
33
+ >>> splits = dataset.keys()
34
+ >>> print([len(dataset[split]) for split in splits])
35
+
36
+ """
37
+ if cfg_path is None:
38
+ cfg = None
39
+ else:
40
+ cfg = load_dataset_config(cfg_path)
41
+
42
+ try:
43
+ builder = registry.get_builder_class(name)(cfg)
44
+ except TypeError:
45
+ print(
46
+ f"Dataset {name} not found. Available datasets:\n"
47
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
48
+ )
49
+ exit(1)
50
+
51
+ if vis_path is not None:
52
+ if data_type is None:
53
+ # use default data type in the config
54
+ data_type = builder.config.data_type
55
+
56
+ assert (
57
+ data_type in builder.config.build_info
58
+ ), f"Invalid data_type {data_type} for {name}."
59
+
60
+ builder.config.build_info.get(data_type).storage = vis_path
61
+
62
+ dataset = builder.build_datasets()
63
+ return dataset
64
+
65
+
66
+ class DatasetZoo:
67
+ def __init__(self) -> None:
68
+ self.dataset_zoo = {
69
+ k: list(v.DATASET_CONFIG_DICT.keys())
70
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
71
+ }
72
+
73
+ def get_names(self):
74
+ return list(self.dataset_zoo.keys())
75
+
76
+
77
+ dataset_zoo = DatasetZoo()
video_llama/datasets/builders/base_dataset_builder.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is from
3
+ Copyright (c) 2022, salesforce.com, inc.
4
+ All rights reserved.
5
+ SPDX-License-Identifier: BSD-3-Clause
6
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import warnings
13
+
14
+ from omegaconf import OmegaConf
15
+ import torch.distributed as dist
16
+ from torchvision.datasets.utils import download_url
17
+
18
+ import video_llama.common.utils as utils
19
+ from video_llama.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20
+ from video_llama.common.registry import registry
21
+ from video_llama.processors.base_processor import BaseProcessor
22
+
23
+
24
+
25
+ class BaseDatasetBuilder:
26
+ train_dataset_cls, eval_dataset_cls = None, None
27
+
28
+ def __init__(self, cfg=None):
29
+ super().__init__()
30
+
31
+ if cfg is None:
32
+ # help to create datasets from default config.
33
+ self.config = load_dataset_config(self.default_config_path())
34
+ elif isinstance(cfg, str):
35
+ self.config = load_dataset_config(cfg)
36
+ else:
37
+ # when called from task.build_dataset()
38
+ self.config = cfg
39
+
40
+ self.data_type = self.config.data_type
41
+
42
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
44
+
45
+ def build_datasets(self):
46
+ # download, split, etc...
47
+ # only called on 1 GPU/TPU in distributed
48
+
49
+ if is_main_process():
50
+ self._download_data()
51
+
52
+ if is_dist_avail_and_initialized():
53
+ dist.barrier()
54
+
55
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
56
+ logging.info("Building datasets...")
57
+ datasets = self.build() # dataset['train'/'val'/'test']
58
+
59
+ return datasets
60
+
61
+ def build_processors(self):
62
+ vis_proc_cfg = self.config.get("vis_processor")
63
+ txt_proc_cfg = self.config.get("text_processor")
64
+
65
+ if vis_proc_cfg is not None:
66
+ vis_train_cfg = vis_proc_cfg.get("train")
67
+ vis_eval_cfg = vis_proc_cfg.get("eval")
68
+
69
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
70
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
71
+
72
+ if txt_proc_cfg is not None:
73
+ txt_train_cfg = txt_proc_cfg.get("train")
74
+ txt_eval_cfg = txt_proc_cfg.get("eval")
75
+
76
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
77
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
78
+
79
+ @staticmethod
80
+ def _build_proc_from_cfg(cfg):
81
+ return (
82
+ registry.get_processor_class(cfg.name).from_config(cfg)
83
+ if cfg is not None
84
+ else None
85
+ )
86
+
87
+ @classmethod
88
+ def default_config_path(cls, type="default"):
89
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
90
+
91
+ def _download_data(self):
92
+ self._download_ann()
93
+ self._download_vis()
94
+
95
+ def _download_ann(self):
96
+ """
97
+ Download annotation files if necessary.
98
+ All the vision-language datasets should have annotations of unified format.
99
+
100
+ storage_path can be:
101
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
102
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
103
+
104
+ Local annotation paths should be relative.
105
+ """
106
+ anns = self.config.build_info.annotations
107
+
108
+ splits = anns.keys()
109
+
110
+ cache_root = registry.get_path("cache_root")
111
+
112
+ for split in splits:
113
+ info = anns[split]
114
+
115
+ urls, storage_paths = info.get("url", None), info.storage
116
+
117
+ if isinstance(urls, str):
118
+ urls = [urls]
119
+ if isinstance(storage_paths, str):
120
+ storage_paths = [storage_paths]
121
+
122
+ assert len(urls) == len(storage_paths)
123
+
124
+ for url_or_filename, storage_path in zip(urls, storage_paths):
125
+ # if storage_path is relative, make it full by prefixing with cache_root.
126
+ if not os.path.isabs(storage_path):
127
+ storage_path = os.path.join(cache_root, storage_path)
128
+
129
+ dirname = os.path.dirname(storage_path)
130
+ if not os.path.exists(dirname):
131
+ os.makedirs(dirname)
132
+
133
+ if os.path.isfile(url_or_filename):
134
+ src, dst = url_or_filename, storage_path
135
+ if not os.path.exists(dst):
136
+ shutil.copyfile(src=src, dst=dst)
137
+ else:
138
+ logging.info("Using existing file {}.".format(dst))
139
+ else:
140
+ if os.path.isdir(storage_path):
141
+ # if only dirname is provided, suffix with basename of URL.
142
+ raise ValueError(
143
+ "Expecting storage_path to be a file path, got directory {}".format(
144
+ storage_path
145
+ )
146
+ )
147
+ else:
148
+ filename = os.path.basename(storage_path)
149
+
150
+ download_url(url=url_or_filename, root=dirname, filename=filename)
151
+
152
+ def _download_vis(self):
153
+
154
+ storage_path = self.config.build_info.get(self.data_type).storage
155
+ storage_path = utils.get_cache_path(storage_path)
156
+
157
+ if not os.path.exists(storage_path):
158
+ warnings.warn(
159
+ f"""
160
+ The specified path {storage_path} for visual inputs does not exist.
161
+ Please provide a correct path to the visual inputs or
162
+ refer to datasets/download_scripts/README.md for downloading instructions.
163
+ """
164
+ )
165
+
166
+ def build(self):
167
+ """
168
+ Create by split datasets inheriting torch.utils.data.Datasets.
169
+
170
+ # build() can be dataset-specific. Overwrite to customize.
171
+ """
172
+ self.build_processors()
173
+
174
+ build_info = self.config.build_info
175
+
176
+ ann_info = build_info.annotations
177
+ vis_info = build_info.get(self.data_type)
178
+
179
+ datasets = dict()
180
+ for split in ann_info.keys():
181
+ if split not in ["train", "val", "test"]:
182
+ continue
183
+
184
+ is_train = split == "train"
185
+
186
+ # processors
187
+ vis_processor = (
188
+ self.vis_processors["train"]
189
+ if is_train
190
+ else self.vis_processors["eval"]
191
+ )
192
+ text_processor = (
193
+ self.text_processors["train"]
194
+ if is_train
195
+ else self.text_processors["eval"]
196
+ )
197
+
198
+ # annotation path
199
+ ann_paths = ann_info.get(split).storage
200
+ if isinstance(ann_paths, str):
201
+ ann_paths = [ann_paths]
202
+
203
+ abs_ann_paths = []
204
+ for ann_path in ann_paths:
205
+ if not os.path.isabs(ann_path):
206
+ ann_path = utils.get_cache_path(ann_path)
207
+ abs_ann_paths.append(ann_path)
208
+ ann_paths = abs_ann_paths
209
+
210
+ # visual data storage path
211
+ vis_path = os.path.join(vis_info.storage, split)
212
+
213
+ if not os.path.isabs(vis_path):
214
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
215
+ vis_path = utils.get_cache_path(vis_path)
216
+
217
+ if not os.path.exists(vis_path):
218
+ warnings.warn("storage path {} does not exist.".format(vis_path))
219
+
220
+ # create datasets
221
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
222
+ datasets[split] = dataset_cls(
223
+ vis_processor=vis_processor,
224
+ text_processor=text_processor,
225
+ ann_paths=ann_paths,
226
+ vis_root=vis_path,
227
+ )
228
+
229
+ return datasets
230
+
231
+
232
+ def load_dataset_config(cfg_path):
233
+ cfg = OmegaConf.load(cfg_path).datasets
234
+ cfg = cfg[list(cfg.keys())[0]]
235
+
236
+ return cfg
video_llama/datasets/builders/image_text_pair_builder.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from video_llama.common.registry import registry
6
+ from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ from video_llama.datasets.datasets.laion_dataset import LaionDataset
8
+ from video_llama.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
9
+
10
+
11
+ @registry.register_builder("cc_sbu")
12
+ class CCSBUBuilder(BaseDatasetBuilder):
13
+ train_dataset_cls = CCSBUDataset
14
+
15
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
16
+
17
+ def _download_ann(self):
18
+ pass
19
+
20
+ def _download_vis(self):
21
+ pass
22
+
23
+ def build(self):
24
+ self.build_processors()
25
+
26
+ build_info = self.config.build_info
27
+
28
+ datasets = dict()
29
+ split = "train"
30
+
31
+ # create datasets
32
+ # [NOTE] return inner_datasets (wds.DataPipeline)
33
+ dataset_cls = self.train_dataset_cls
34
+ datasets[split] = dataset_cls(
35
+ vis_processor=self.vis_processors[split],
36
+ text_processor=self.text_processors[split],
37
+ location=build_info.storage,
38
+ ).inner_dataset
39
+
40
+ return datasets
41
+
42
+
43
+ @registry.register_builder("laion")
44
+ class LaionBuilder(BaseDatasetBuilder):
45
+ train_dataset_cls = LaionDataset
46
+
47
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
48
+
49
+ def _download_ann(self):
50
+ pass
51
+
52
+ def _download_vis(self):
53
+ pass
54
+
55
+ def build(self):
56
+ self.build_processors()
57
+
58
+ build_info = self.config.build_info
59
+
60
+ datasets = dict()
61
+ split = "train"
62
+
63
+ # create datasets
64
+ # [NOTE] return inner_datasets (wds.DataPipeline)
65
+ dataset_cls = self.train_dataset_cls
66
+ datasets[split] = dataset_cls(
67
+ vis_processor=self.vis_processors[split],
68
+ text_processor=self.text_processors[split],
69
+ location=build_info.storage,
70
+ ).inner_dataset
71
+
72
+ return datasets
73
+
74
+
75
+ @registry.register_builder("cc_sbu_align")
76
+ class CCSBUAlignBuilder(BaseDatasetBuilder):
77
+ train_dataset_cls = CCSBUAlignDataset
78
+
79
+ DATASET_CONFIG_DICT = {
80
+ "default": "configs/datasets/cc_sbu/align.yaml",
81
+ }
82
+
83
+ def build_datasets(self):
84
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
85
+ logging.info("Building datasets...")
86
+ self.build_processors()
87
+
88
+ build_info = self.config.build_info
89
+ storage_path = build_info.storage
90
+
91
+ datasets = dict()
92
+
93
+ if not os.path.exists(storage_path):
94
+ warnings.warn("storage path {} does not exist.".format(storage_path))
95
+
96
+ # create datasets
97
+ dataset_cls = self.train_dataset_cls
98
+ datasets['train'] = dataset_cls(
99
+ vis_processor=self.vis_processors["train"],
100
+ text_processor=self.text_processors["train"],
101
+ ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
102
+ vis_root=os.path.join(storage_path, 'image'),
103
+ )
104
+
105
+ return datasets
106
+
video_llama/datasets/builders/instruct_builder.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from video_llama.common.registry import registry
6
+ from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ from video_llama.datasets.datasets.laion_dataset import LaionDataset
8
+ from video_llama.datasets.datasets.llava_instruct_dataset import Instruct_Dataset
9
+ from video_llama.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset
10
+
11
+ @registry.register_builder("instruct")
12
+ class Instruct_Builder(BaseDatasetBuilder):
13
+ train_dataset_cls = Instruct_Dataset
14
+
15
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
16
+
17
+ def _download_ann(self):
18
+ pass
19
+
20
+ def _download_vis(self):
21
+ pass
22
+
23
+ def build(self):
24
+ self.build_processors()
25
+ datasets = dict()
26
+ split = "train"
27
+
28
+ build_info = self.config.build_info
29
+ dataset_cls = self.train_dataset_cls
30
+ if self.config.num_video_query_token:
31
+ num_video_query_token = self.config.num_video_query_token
32
+ else:
33
+ num_video_query_token = 32
34
+
35
+ if self.config.tokenizer_name:
36
+ tokenizer_name = self.config.tokenizer_name
37
+ else:
38
+ tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
39
+
40
+
41
+ datasets[split] = dataset_cls(
42
+ vis_processor=self.vis_processors[split],
43
+ text_processor=self.text_processors[split],
44
+ vis_root=build_info.videos_dir,
45
+ ann_root=build_info.anno_dir,
46
+ num_video_query_token = num_video_query_token,
47
+ tokenizer_name = tokenizer_name,
48
+ data_type = self.config.data_type
49
+ )
50
+
51
+ return datasets
52
+
53
+ @registry.register_builder("webvid_instruct")
54
+ class WebvidInstruct_Builder(Instruct_Builder):
55
+ train_dataset_cls = Video_Instruct_Dataset
56
+
57
+ DATASET_CONFIG_DICT = {
58
+ "default": "configs/datasets/instruct/webvid_instruct.yaml",
59
+ }
60
+
61
+ @registry.register_builder("webvid_instruct_zh")
62
+ class WebvidInstruct_zh_Builder(Instruct_Builder):
63
+ train_dataset_cls = Video_Instruct_Dataset
64
+
65
+ DATASET_CONFIG_DICT = {
66
+ "default": "configs/datasets/instruct/webvid_instruct.yaml",
67
+ }
68
+
69
+
70
+
71
+ @registry.register_builder("llava_instruct")
72
+ class LlavaInstruct_Builder(Instruct_Builder):
73
+ train_dataset_cls = Instruct_Dataset
74
+
75
+ DATASET_CONFIG_DICT = {
76
+ "default": "configs/datasets/instruct/llava_instruct.yaml",
77
+ }
78
+
video_llama/datasets/builders/video_caption_builder.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from video_llama.common.registry import registry
6
+ from video_llama.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ from video_llama.datasets.datasets.webvid_datasets import WebvidDataset
8
+
9
+ @registry.register_builder("webvid")
10
+ class WebvidBuilder(BaseDatasetBuilder):
11
+ train_dataset_cls = WebvidDataset
12
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"}
13
+
14
+ def _download_ann(self):
15
+ pass
16
+
17
+ def _download_vis(self):
18
+ pass
19
+
20
+ def build(self):
21
+ self.build_processors()
22
+ datasets = dict()
23
+ split = "train"
24
+
25
+ build_info = self.config.build_info
26
+ dataset_cls = self.train_dataset_cls
27
+ datasets[split] = dataset_cls(
28
+ vis_processor=self.vis_processors[split],
29
+ text_processor=self.text_processors[split],
30
+ vis_root=build_info.videos_dir,
31
+ ann_root=build_info.anno_dir
32
+ )
33
+
34
+ return datasets
video_llama/datasets/data_utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import gzip
9
+ import logging
10
+ import os
11
+ import random as rnd
12
+ import tarfile
13
+ import zipfile
14
+ import random
15
+ from typing import List
16
+ from tqdm import tqdm
17
+
18
+ import decord
19
+ from decord import VideoReader
20
+ import webdataset as wds
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils.data.dataset import IterableDataset
24
+
25
+ from video_llama.common.registry import registry
26
+ from video_llama.datasets.datasets.base_dataset import ConcatDataset
27
+
28
+
29
+ decord.bridge.set_bridge("torch")
30
+ MAX_INT = registry.get("MAX_INT")
31
+
32
+
33
+ class ChainDataset(wds.DataPipeline):
34
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
35
+
36
+ This class is useful to assemble different existing dataset streams. The
37
+ chaining operation is done on-the-fly, so concatenating large-scale
38
+ datasets with this class will be efficient.
39
+
40
+ Args:
41
+ datasets (iterable of IterableDataset): datasets to be chained together
42
+ """
43
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44
+ super().__init__()
45
+ self.datasets = datasets
46
+ self.prob = []
47
+ self.names = []
48
+ for dataset in self.datasets:
49
+ if hasattr(dataset, 'name'):
50
+ self.names.append(dataset.name)
51
+ else:
52
+ self.names.append('Unknown')
53
+ if hasattr(dataset, 'sample_ratio'):
54
+ self.prob.append(dataset.sample_ratio)
55
+ else:
56
+ self.prob.append(1)
57
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
58
+
59
+ def __iter__(self):
60
+ datastreams = [iter(dataset) for dataset in self.datasets]
61
+ while True:
62
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63
+ yield next(select_datastream)
64
+
65
+
66
+ def apply_to_sample(f, sample):
67
+ if len(sample) == 0:
68
+ return {}
69
+
70
+ def _apply(x):
71
+ if torch.is_tensor(x):
72
+ return f(x)
73
+ elif isinstance(x, dict):
74
+ return {key: _apply(value) for key, value in x.items()}
75
+ elif isinstance(x, list):
76
+ return [_apply(x) for x in x]
77
+ else:
78
+ return x
79
+
80
+ return _apply(sample)
81
+
82
+
83
+ def move_to_cuda(sample):
84
+ def _move_to_cuda(tensor):
85
+ return tensor.cuda()
86
+
87
+ return apply_to_sample(_move_to_cuda, sample)
88
+
89
+
90
+ def prepare_sample(samples, cuda_enabled=True):
91
+ if cuda_enabled:
92
+ samples = move_to_cuda(samples)
93
+
94
+ # TODO fp16 support
95
+
96
+ return samples
97
+
98
+
99
+ def reorg_datasets_by_split(datasets):
100
+ """
101
+ Organizes datasets by split.
102
+
103
+ Args:
104
+ datasets: dict of torch.utils.data.Dataset objects by name.
105
+
106
+ Returns:
107
+ Dict of datasets by split {split_name: List[Datasets]}.
108
+ """
109
+ # if len(datasets) == 1:
110
+ # return datasets[list(datasets.keys())[0]]
111
+ # else:
112
+ reorg_datasets = dict()
113
+
114
+ # reorganize by split
115
+ for _, dataset in datasets.items():
116
+ for split_name, dataset_split in dataset.items():
117
+ if split_name not in reorg_datasets:
118
+ reorg_datasets[split_name] = [dataset_split]
119
+ else:
120
+ reorg_datasets[split_name].append(dataset_split)
121
+
122
+ return reorg_datasets
123
+
124
+
125
+ def concat_datasets(datasets):
126
+ """
127
+ Concatenates multiple datasets into a single dataset.
128
+
129
+ It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
130
+ generic IterableDataset because it requires creating separate samplers.
131
+
132
+ Now only supports conctenating training datasets and assuming validation and testing
133
+ have only a single dataset. This is because metrics should not be computed on the concatenated
134
+ datasets.
135
+
136
+ Args:
137
+ datasets: dict of torch.utils.data.Dataset objects by split.
138
+
139
+ Returns:
140
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141
+ "val" and "test" remain the same.
142
+
143
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
144
+ a tuple, where the first element is a concatenated map-style dataset and the second
145
+ element is a chained DataPipeline dataset.
146
+
147
+ """
148
+ # concatenate datasets in the same split
149
+ for split_name in datasets:
150
+ if split_name != "train":
151
+ assert (
152
+ len(datasets[split_name]) == 1
153
+ ), "Do not support multiple {} datasets.".format(split_name)
154
+ datasets[split_name] = datasets[split_name][0]
155
+ else:
156
+ iterable_datasets, map_datasets = [], []
157
+ for dataset in datasets[split_name]:
158
+ if isinstance(dataset, wds.DataPipeline):
159
+ logging.info(
160
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
161
+ dataset
162
+ )
163
+ )
164
+ iterable_datasets.append(dataset)
165
+ elif isinstance(dataset, IterableDataset):
166
+ raise NotImplementedError(
167
+ "Do not support concatenation of generic IterableDataset."
168
+ )
169
+ else:
170
+ map_datasets.append(dataset)
171
+
172
+ # if len(iterable_datasets) > 0:
173
+ # concatenate map-style datasets and iterable-style datasets separately
174
+ if len(iterable_datasets) > 1:
175
+ chained_datasets = (
176
+ ChainDataset(iterable_datasets)
177
+ )
178
+ elif len(iterable_datasets) == 1:
179
+ chained_datasets = iterable_datasets[0]
180
+ else:
181
+ chained_datasets = None
182
+
183
+ concat_datasets = (
184
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185
+ )
186
+
187
+ train_datasets = concat_datasets, chained_datasets
188
+ train_datasets = tuple([x for x in train_datasets if x is not None])
189
+ train_datasets = (
190
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
191
+ )
192
+
193
+ datasets[split_name] = train_datasets
194
+
195
+ return datasets
196
+
video_llama/datasets/datasets/__init__.py ADDED
File without changes
video_llama/datasets/datasets/base_dataset.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import json
9
+ from typing import Iterable
10
+
11
+ from torch.utils.data import Dataset, ConcatDataset
12
+ from torch.utils.data.dataloader import default_collate
13
+
14
+
15
+ class BaseDataset(Dataset):
16
+ def __init__(
17
+ self, vis_processor=None, text_processor=None, vis_root=None, ann_paths=[]
18
+ ):
19
+ """
20
+ vis_root (string): Root directory of images (e.g. coco/images/)
21
+ ann_root (string): directory to store the annotation file
22
+ """
23
+ self.vis_root = vis_root
24
+
25
+ self.annotation = []
26
+ for ann_path in ann_paths:
27
+ self.annotation.extend(json.load(open(ann_path, "r"))['annotations'])
28
+
29
+ self.vis_processor = vis_processor
30
+ self.text_processor = text_processor
31
+
32
+ self._add_instance_ids()
33
+
34
+ def __len__(self):
35
+ return len(self.annotation)
36
+
37
+ def collater(self, samples):
38
+ return default_collate(samples)
39
+
40
+ def set_processors(self, vis_processor, text_processor):
41
+ self.vis_processor = vis_processor
42
+ self.text_processor = text_processor
43
+
44
+ def _add_instance_ids(self, key="instance_id"):
45
+ for idx, ann in enumerate(self.annotation):
46
+ ann[key] = str(idx)
47
+
48
+
49
+ class ConcatDataset(ConcatDataset):
50
+ def __init__(self, datasets: Iterable[Dataset]) -> None:
51
+ super().__init__(datasets)
52
+
53
+ def collater(self, samples):
54
+ # TODO For now only supports datasets with same underlying collater implementations
55
+
56
+ all_keys = set()
57
+ for s in samples:
58
+ all_keys.update(s)
59
+
60
+ shared_keys = all_keys
61
+ for s in samples:
62
+ shared_keys = shared_keys & set(s.keys())
63
+
64
+ samples_shared_keys = []
65
+ for s in samples:
66
+ samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
67
+
68
+ return self.datasets[0].collater(samples_shared_keys)
video_llama/datasets/datasets/caption_datasets.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ from collections import OrderedDict
10
+
11
+ from video_llama.datasets.datasets.base_dataset import BaseDataset
12
+ from PIL import Image
13
+
14
+
15
+ class __DisplMixin:
16
+ def displ_item(self, index):
17
+ sample, ann = self.__getitem__(index), self.annotation[index]
18
+
19
+ return OrderedDict(
20
+ {
21
+ "file": ann["image"],
22
+ "caption": ann["caption"],
23
+ "image": sample["image"],
24
+ }
25
+ )
26
+
27
+
28
+ class CaptionDataset(BaseDataset, __DisplMixin):
29
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
30
+ """
31
+ vis_root (string): Root directory of images (e.g. coco/images/)
32
+ ann_root (string): directory to store the annotation file
33
+ """
34
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
35
+
36
+ self.img_ids = {}
37
+ n = 0
38
+ for ann in self.annotation:
39
+ img_id = ann["image_id"]
40
+ if img_id not in self.img_ids.keys():
41
+ self.img_ids[img_id] = n
42
+ n += 1
43
+
44
+ def __getitem__(self, index):
45
+
46
+ # TODO this assumes image input, not general enough
47
+ ann = self.annotation[index]
48
+
49
+ img_file = '{:0>12}.jpg'.format(ann["image_id"])
50
+ image_path = os.path.join(self.vis_root, img_file)
51
+ image = Image.open(image_path).convert("RGB")
52
+
53
+ image = self.vis_processor(image)
54
+ caption = self.text_processor(ann["caption"])
55
+
56
+ return {
57
+ "image": image,
58
+ "text_input": caption,
59
+ "image_id": self.img_ids[ann["image_id"]],
60
+ }
61
+
62
+
63
+ class CaptionEvalDataset(BaseDataset, __DisplMixin):
64
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
65
+ """
66
+ vis_root (string): Root directory of images (e.g. coco/images/)
67
+ ann_root (string): directory to store the annotation file
68
+ split (string): val or test
69
+ """
70
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
71
+
72
+ def __getitem__(self, index):
73
+
74
+ ann = self.annotation[index]
75
+
76
+ image_path = os.path.join(self.vis_root, ann["image"])
77
+ image = Image.open(image_path).convert("RGB")
78
+
79
+ image = self.vis_processor(image)
80
+
81
+ return {
82
+ "image": image,
83
+ "image_id": ann["image_id"],
84
+ "instance_id": ann["instance_id"],
85
+ }
video_llama/datasets/datasets/cc_sbu_dataset.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image
3
+ import webdataset as wds
4
+ from video_llama.datasets.datasets.base_dataset import BaseDataset
5
+ from video_llama.datasets.datasets.caption_datasets import CaptionDataset
6
+
7
+
8
+ class CCSBUDataset(BaseDataset):
9
+ def __init__(self, vis_processor, text_processor, location):
10
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
11
+
12
+ self.inner_dataset = wds.DataPipeline(
13
+ wds.ResampledShards(location),
14
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
15
+ wds.shuffle(1000, handler=wds.warn_and_continue),
16
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
17
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
18
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
19
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
20
+ )
21
+
22
+ def to_dict(self, sample):
23
+ return {
24
+ "image": sample[0],
25
+ "text_input": self.text_processor(sample[1]["caption"]),
26
+ "type":'image',
27
+ }
28
+
29
+
30
+ class CCSBUAlignDataset(CaptionDataset):
31
+
32
+ def __getitem__(self, index):
33
+
34
+ # TODO this assumes image input, not general enough
35
+ ann = self.annotation[index]
36
+
37
+ img_file = '{}.jpg'.format(ann["image_id"])
38
+ image_path = os.path.join(self.vis_root, img_file)
39
+ image = Image.open(image_path).convert("RGB")
40
+
41
+ image = self.vis_processor(image)
42
+ caption = ann["caption"]
43
+
44
+ return {
45
+ "image": image,
46
+ "text_input": caption,
47
+ "image_id": self.img_ids[ann["image_id"]],
48
+ "type":'image',
49
+ }
video_llama/datasets/datasets/dataloader_utils.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import time
9
+ import random
10
+ import torch
11
+ from video_llama.datasets.data_utils import move_to_cuda
12
+ from torch.utils.data import DataLoader
13
+
14
+
15
+ class MultiIterLoader:
16
+ """
17
+ A simple wrapper for iterating over multiple iterators.
18
+
19
+ Args:
20
+ loaders (List[Loader]): List of Iterator loaders.
21
+ ratios (List[float]): List of ratios to sample from each loader. If None, all loaders are sampled uniformly.
22
+ """
23
+
24
+ def __init__(self, loaders, ratios=None):
25
+ # assert all loaders has __next__ method
26
+ for loader in loaders:
27
+ assert hasattr(
28
+ loader, "__next__"
29
+ ), "Loader {} has no __next__ method.".format(loader)
30
+
31
+ if ratios is None:
32
+ ratios = [1.0] * len(loaders)
33
+ else:
34
+ assert len(ratios) == len(loaders)
35
+ ratios = [float(ratio) / sum(ratios) for ratio in ratios]
36
+
37
+ self.loaders = loaders
38
+ self.ratios = ratios
39
+
40
+ def __next__(self):
41
+ # random sample from each loader by ratio
42
+ loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0]
43
+ return next(self.loaders[loader_idx])
44
+
45
+
46
+ class PrefetchLoader(object):
47
+ """
48
+ Modified from https://github.com/ChenRocks/UNITER.
49
+
50
+ overlap compute and cuda data transfer
51
+ (copied and then modified from nvidia apex)
52
+ """
53
+
54
+ def __init__(self, loader):
55
+ self.loader = loader
56
+ self.stream = torch.cuda.Stream()
57
+
58
+ def __iter__(self):
59
+ loader_it = iter(self.loader)
60
+ self.preload(loader_it)
61
+ batch = self.next(loader_it)
62
+ while batch is not None:
63
+ is_tuple = isinstance(batch, tuple)
64
+ if is_tuple:
65
+ task, batch = batch
66
+
67
+ if is_tuple:
68
+ yield task, batch
69
+ else:
70
+ yield batch
71
+ batch = self.next(loader_it)
72
+
73
+ def __len__(self):
74
+ return len(self.loader)
75
+
76
+ def preload(self, it):
77
+ try:
78
+ self.batch = next(it)
79
+ except StopIteration:
80
+ self.batch = None
81
+ return
82
+ # if record_stream() doesn't work, another option is to make sure
83
+ # device inputs are created on the main stream.
84
+ # self.next_input_gpu = torch.empty_like(self.next_input,
85
+ # device='cuda')
86
+ # self.next_target_gpu = torch.empty_like(self.next_target,
87
+ # device='cuda')
88
+ # Need to make sure the memory allocated for next_* is not still in use
89
+ # by the main stream at the time we start copying to next_*:
90
+ # self.stream.wait_stream(torch.cuda.current_stream())
91
+ with torch.cuda.stream(self.stream):
92
+ self.batch = move_to_cuda(self.batch)
93
+ # more code for the alternative if record_stream() doesn't work:
94
+ # copy_ will record the use of the pinned source tensor in this
95
+ # side stream.
96
+ # self.next_input_gpu.copy_(self.next_input, non_blocking=True)
97
+ # self.next_target_gpu.copy_(self.next_target, non_blocking=True)
98
+ # self.next_input = self.next_input_gpu
99
+ # self.next_target = self.next_target_gpu
100
+
101
+ def next(self, it):
102
+ torch.cuda.current_stream().wait_stream(self.stream)
103
+ batch = self.batch
104
+ if batch is not None:
105
+ record_cuda_stream(batch)
106
+ self.preload(it)
107
+ return batch
108
+
109
+ def __getattr__(self, name):
110
+ method = self.loader.__getattribute__(name)
111
+ return method
112
+
113
+
114
+ def record_cuda_stream(batch):
115
+ if isinstance(batch, torch.Tensor):
116
+ batch.record_stream(torch.cuda.current_stream())
117
+ elif isinstance(batch, list) or isinstance(batch, tuple):
118
+ for t in batch:
119
+ record_cuda_stream(t)
120
+ elif isinstance(batch, dict):
121
+ for t in batch.values():
122
+ record_cuda_stream(t)
123
+ else:
124
+ pass
125
+
126
+
127
+ class IterLoader:
128
+ """
129
+ A wrapper to convert DataLoader as an infinite iterator.
130
+
131
+ Modified from:
132
+ https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/iter_based_runner.py
133
+ """
134
+
135
+ def __init__(self, dataloader: DataLoader, use_distributed: bool = False):
136
+ self._dataloader = dataloader
137
+ self.iter_loader = iter(self._dataloader)
138
+ self._use_distributed = use_distributed
139
+ self._epoch = 0
140
+
141
+ @property
142
+ def epoch(self) -> int:
143
+ return self._epoch
144
+
145
+ def __next__(self):
146
+ try:
147
+ data = next(self.iter_loader)
148
+ except StopIteration:
149
+ self._epoch += 1
150
+ if hasattr(self._dataloader.sampler, "set_epoch") and self._use_distributed:
151
+ self._dataloader.sampler.set_epoch(self._epoch)
152
+ time.sleep(2) # Prevent possible deadlock during epoch transition
153
+ self.iter_loader = iter(self._dataloader)
154
+ data = next(self.iter_loader)
155
+
156
+ return data
157
+
158
+ def __iter__(self):
159
+ return self
160
+
161
+ def __len__(self):
162
+ return len(self._dataloader)
video_llama/datasets/datasets/laion_dataset.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import webdataset as wds
9
+ from video_llama.datasets.datasets.base_dataset import BaseDataset
10
+
11
+
12
+ class LaionDataset(BaseDataset):
13
+ def __init__(self, vis_processor, text_processor, location):
14
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
15
+
16
+ self.inner_dataset = wds.DataPipeline(
17
+ wds.ResampledShards(location),
18
+ wds.tarfile_to_samples(handler=wds.warn_and_continue),
19
+ wds.shuffle(1000, handler=wds.warn_and_continue),
20
+ wds.decode("pilrgb", handler=wds.warn_and_continue),
21
+ wds.to_tuple("jpg", "json", handler=wds.warn_and_continue),
22
+ wds.map_tuple(self.vis_processor, handler=wds.warn_and_continue),
23
+ wds.map(self.to_dict, handler=wds.warn_and_continue),
24
+ )
25
+
26
+ def to_dict(self, sample):
27
+ return {
28
+ "image": sample[0],
29
+ "text_input": self.text_processor(sample[1]["caption"]),
30
+ }
31
+
video_llama/datasets/datasets/llava_instruct_dataset.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from video_llama.datasets.datasets.base_dataset import BaseDataset
3
+ from video_llama.datasets.datasets.caption_datasets import CaptionDataset
4
+ import pandas as pd
5
+ import decord
6
+ from decord import VideoReader
7
+ import random
8
+ import torch
9
+ from torch.utils.data.dataloader import default_collate
10
+ from PIL import Image
11
+ from typing import Dict, Optional, Sequence
12
+ import transformers
13
+ import pathlib
14
+ import json
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
16
+ from video_llama.conversation.conversation_video import Conversation,SeparatorStyle
17
+ DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
18
+ DEFAULT_IMAGE_TOKEN = "<image>"
19
+ import copy
20
+ IGNORE_INDEX = -100
21
+ image_conversation = Conversation(
22
+ system="",
23
+ roles=("Human", "Assistant"),
24
+ messages=[],
25
+ offset=0,
26
+ sep_style=SeparatorStyle.SINGLE,
27
+ sep="###",
28
+ )
29
+ IGNORE_INDEX = -100
30
+
31
+ class Instruct_Dataset(BaseDataset):
32
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'image'):
33
+ """
34
+ vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
35
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
36
+ split (string): val or test
37
+ """
38
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
39
+
40
+ data_path = pathlib.Path(ann_root)
41
+ with data_path.open(encoding='utf-8') as f:
42
+ self.annotation = json.load(f)
43
+
44
+ self.vis_root = vis_root
45
+ self.resize_size = 224
46
+ self.num_frm = 8
47
+ self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
48
+ self.tokenizer.pad_token = self.tokenizer.eos_token
49
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
50
+ self.num_video_query_token = num_video_query_token
51
+ self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
52
+
53
+ self.transform = AlproVideoTrainProcessor(
54
+ image_size=self.resize_size, n_frms = self.num_frm
55
+ ).transform
56
+ self.data_type = data_type
57
+
58
+ def _get_image_path(self, sample):
59
+ rel_video_fp ='COCO_train2014_' + sample['image']
60
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
61
+ return full_video_fp
62
+
63
+ def __getitem__(self, index):
64
+ num_retries = 10 # skip error videos
65
+ for _ in range(num_retries):
66
+ try:
67
+ sample = self.annotation[index]
68
+
69
+ image_path = self._get_image_path(sample)
70
+ conversation_list = sample['conversations']
71
+ image = Image.open(image_path).convert("RGB")
72
+
73
+ image = self.vis_processor(image)
74
+ # text = self.text_processor(text)
75
+ sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token)
76
+ data_dict = preprocess(
77
+ sources,
78
+ self.tokenizer)
79
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
80
+ labels=data_dict["labels"][0])
81
+
82
+ # image exist in the data
83
+ data_dict['image'] = image
84
+ except:
85
+ print(f"Failed to load examples with image: {image_path}. "
86
+ f"Will randomly sample an example as a replacement.")
87
+ index = random.randint(0, len(self) - 1)
88
+ continue
89
+ break
90
+ else:
91
+ raise RuntimeError(f"Failed to fetch image after {num_retries} retries.")
92
+ # "image_id" is kept to stay compatible with the COCO evaluation format
93
+ return {
94
+ "image": image,
95
+ "text_input": data_dict["input_ids"],
96
+ "labels": data_dict["labels"],
97
+ "type":'image',
98
+ }
99
+
100
+ def __len__(self):
101
+ return len(self.annotation)
102
+
103
+ def collater(self, instances):
104
+ input_ids, labels = tuple([instance[key] for instance in instances]
105
+ for key in ("text_input", "labels"))
106
+ input_ids = torch.nn.utils.rnn.pad_sequence(
107
+ input_ids,
108
+ batch_first=True,
109
+ padding_value=self.tokenizer.pad_token_id)
110
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
111
+ batch_first=True,
112
+ padding_value=IGNORE_INDEX)
113
+ batch = dict(
114
+ input_ids=input_ids,
115
+ labels=labels,
116
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
117
+ )
118
+
119
+ if 'image' in instances[0]:
120
+ images = [instance['image'] for instance in instances]
121
+ if all(x is not None and x.shape == images[0].shape for x in images):
122
+ batch['images'] = torch.stack(images)
123
+ else:
124
+ batch['images'] = images
125
+ batch['conv_type'] = 'multi'
126
+ return batch
127
+
128
+
129
+ def preprocess_multimodal(
130
+ conversation_list: Sequence[str],
131
+ multimodal_cfg: dict,
132
+ cur_token_len: int,
133
+ ) -> Dict:
134
+ # 将conversational list中
135
+ is_multimodal = True
136
+ # image_token_len = multimodal_cfg['image_token_len']
137
+ image_token_len = cur_token_len
138
+
139
+ for sentence in conversation_list:
140
+ replace_token = '<Image>'+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len+'/<Image>'
141
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
142
+
143
+ return [conversation_list]
144
+
145
+ def _add_speaker_and_signal(header, source, get_conversation=True):
146
+ """Add speaker and start/end signal on each round."""
147
+ BEGIN_SIGNAL = "###"
148
+ END_SIGNAL = "\n"
149
+ conversation = header
150
+ for sentence in source:
151
+ from_str = sentence["from"]
152
+ if from_str.lower() == "human":
153
+ from_str = image_conversation.roles[0]
154
+ elif from_str.lower() == "gpt":
155
+ from_str = image_conversation.roles[1]
156
+ else:
157
+ from_str = 'unknown'
158
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
159
+ sentence["value"] + END_SIGNAL)
160
+ if get_conversation:
161
+ conversation += sentence["value"]
162
+ conversation += BEGIN_SIGNAL
163
+ return conversation
164
+
165
+ def _tokenize_fn(strings: Sequence[str],
166
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
167
+ """Tokenize a list of strings."""
168
+ tokenized_list = [
169
+ tokenizer(
170
+ text,
171
+ return_tensors="pt",
172
+ padding="longest",
173
+ max_length=512,
174
+ truncation=True,
175
+ ) for text in strings
176
+ ]
177
+ input_ids = labels = [
178
+ tokenized.input_ids[0] for tokenized in tokenized_list
179
+ ]
180
+ input_ids_lens = labels_lens = [
181
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
182
+ for tokenized in tokenized_list
183
+ ]
184
+ return dict(
185
+ input_ids=input_ids,
186
+ labels=labels,
187
+ input_ids_lens=input_ids_lens,
188
+ labels_lens=labels_lens,
189
+ )
190
+
191
+ def preprocess(
192
+ sources: Sequence[str],
193
+ tokenizer: transformers.PreTrainedTokenizer,
194
+ ) -> Dict:
195
+ """
196
+ Given a list of sources, each is a conversation list. This transform:
197
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
198
+ 2. Concatenate conversations together;
199
+ 3. Tokenize the concatenated conversation;
200
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
201
+ """
202
+ # add end signal and concatenate together
203
+ conversations = []
204
+ for source in sources:
205
+ header = f"{image_conversation.system}\n\n"
206
+ conversation = _add_speaker_and_signal(header, source)
207
+ conversations.append(conversation)
208
+ # tokenize conversations
209
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
210
+ input_ids = conversations_tokenized["input_ids"]
211
+ targets = copy.deepcopy(input_ids)
212
+ for target, source in zip(targets, sources):
213
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
214
+ tokenizer)["input_ids_lens"]
215
+ speakers = [sentence["from"] for sentence in source]
216
+ _mask_targets(target, tokenized_lens, speakers)
217
+
218
+ return dict(input_ids=input_ids, labels=targets)
219
+
220
+ def _mask_targets(target, tokenized_lens, speakers):
221
+ # cur_idx = 0
222
+ cur_idx = tokenized_lens[0]
223
+ tokenized_lens = tokenized_lens[1:]
224
+ target[:cur_idx] = IGNORE_INDEX
225
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
226
+ if speaker == "human":
227
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
228
+ cur_idx += tokenized_len
video_llama/datasets/datasets/video_instruct_dataset.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from video_llama.datasets.datasets.base_dataset import BaseDataset
3
+ from video_llama.datasets.datasets.caption_datasets import CaptionDataset
4
+ import pandas as pd
5
+ import decord
6
+ from decord import VideoReader
7
+ import random
8
+ import torch
9
+ from torch.utils.data.dataloader import default_collate
10
+ from PIL import Image
11
+ from typing import Dict, Optional, Sequence
12
+ import transformers
13
+ import pathlib
14
+ import json
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
16
+ import copy
17
+ from video_llama.processors import transforms_video,AlproVideoTrainProcessor
18
+ from torchvision import transforms
19
+ from video_llama.processors.video_processor import ToTHWC,ToUint8,load_video
20
+ from video_llama.conversation.conversation_video import Conversation,SeparatorStyle
21
+
22
+ DEFAULT_IMAGE_PATCH_TOKEN = '<ImageHere>'
23
+ video_conversation = Conversation(
24
+ system="",
25
+ roles=("Human", "Assistant"),
26
+ messages=[],
27
+ offset=0,
28
+ sep_style=SeparatorStyle.SINGLE,
29
+ sep="###",
30
+ )
31
+ IGNORE_INDEX = -100
32
+
33
+ class Video_Instruct_Dataset(BaseDataset):
34
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'video'):
35
+ """
36
+ vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
37
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
38
+ split (string): val or test
39
+ """
40
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
41
+
42
+ data_path = pathlib.Path(ann_root)
43
+ with data_path.open(encoding='utf-8') as f:
44
+ self.annotation = json.load(f)
45
+
46
+ self.num_video_query_token = num_video_query_token
47
+ self.vis_root = vis_root
48
+ self.resize_size = 224
49
+ self.num_frm = 8
50
+ self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
51
+ self.tokenizer.pad_token = self.tokenizer.eos_token
52
+ self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
53
+ self.IMAGE_PATCH_TOKEN_ID = self.tokenizer.get_vocab()[DEFAULT_IMAGE_PATCH_TOKEN]
54
+
55
+ self.transform = AlproVideoTrainProcessor(
56
+ image_size=self.resize_size, n_frms = self.num_frm
57
+ ).transform
58
+ self.data_type = data_type
59
+
60
+ def _get_video_path(self, sample):
61
+ rel_video_fp = sample['video']
62
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
63
+ return full_video_fp
64
+
65
+ def __getitem__(self, index):
66
+ num_retries = 10 # skip error videos
67
+ for _ in range(num_retries):
68
+ try:
69
+ sample = self.annotation[index]
70
+
71
+ video_path = self._get_video_path(sample)
72
+ conversation_list = sample['QA']
73
+
74
+ video, msg = load_video(
75
+ video_path=video_path,
76
+ n_frms=self.num_frm,
77
+ height=self.resize_size,
78
+ width=self.resize_size,
79
+ sampling ="uniform", return_msg = True
80
+ )
81
+ video = self.transform(video)
82
+ if 'cn' in self.data_type:
83
+ msg = ""
84
+ # 添加视频<DEFAULT_IMAGE_PATCH_TOKEN>,以及msg到convsation list 0
85
+ sources = preprocess_multimodal(copy.deepcopy(conversation_list), None, cur_token_len=self.num_video_query_token,msg = msg)
86
+ new_sources = convert_source_vicuna_format(sources)
87
+
88
+ data_dict = preprocess(
89
+ new_sources,
90
+ self.tokenizer)
91
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
92
+ labels=data_dict["labels"][0])
93
+ # image exist in the data
94
+ data_dict['image'] = video
95
+ except:
96
+ print(f"Failed to load examples with video: {video_path}. "
97
+ f"Will randomly sample an example as a replacement.")
98
+ index = random.randint(0, len(self) - 1)
99
+ continue
100
+ break
101
+ else:
102
+ raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
103
+ # "image_id" is kept to stay compatible with the COCO evaluation format
104
+ return {
105
+ "image": video,
106
+ "text_input": data_dict["input_ids"],
107
+ "labels": data_dict["labels"],
108
+ "type":'video',
109
+ }
110
+
111
+ def __len__(self):
112
+ return len(self.annotation)
113
+
114
+ def collater(self, instances):
115
+ input_ids, labels = tuple([instance[key] for instance in instances]
116
+ for key in ("text_input", "labels"))
117
+ input_ids = torch.nn.utils.rnn.pad_sequence(
118
+ input_ids,
119
+ batch_first=True,
120
+ padding_value=self.tokenizer.pad_token_id)
121
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
122
+ batch_first=True,
123
+ padding_value=IGNORE_INDEX)
124
+ batch = dict(
125
+ input_ids=input_ids,
126
+ labels=labels,
127
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
128
+ )
129
+
130
+ if 'image' in instances[0]:
131
+ images = [instance['image'] for instance in instances]
132
+ if all(x is not None and x.shape == images[0].shape for x in images):
133
+ batch['images'] = torch.stack(images)
134
+ else:
135
+ batch['images'] = images
136
+ batch['conv_type'] = 'multi'
137
+ return batch
138
+
139
+ def convert_source_vicuna_format(sources):
140
+ new_sources = []
141
+ for source in sources:
142
+ new_source = []
143
+ for i, sentence in enumerate(source):
144
+ role_0_msg = sentence['q']
145
+ role_1_msg = sentence['a']
146
+ new_source.append({
147
+ 'from':'human',
148
+ 'value': role_0_msg,
149
+ })
150
+ new_source.append({
151
+ 'from':'gpt',
152
+ 'value': role_1_msg,
153
+ })
154
+ new_sources.append(new_source)
155
+ return new_sources
156
+
157
+ def preprocess_multimodal(
158
+ conversation_list: Sequence[str],
159
+ multimodal_cfg: dict,
160
+ cur_token_len: int,
161
+ msg=''
162
+ ) -> Dict:
163
+ # 将conversational list中
164
+ is_multimodal = True
165
+ # image_token_len = multimodal_cfg['image_token_len']
166
+ image_token_len = cur_token_len
167
+ conversation_list[0]["q"] = "<Video>"+DEFAULT_IMAGE_PATCH_TOKEN * image_token_len +"</Video> " + msg + conversation_list[0]["q"]
168
+ return [conversation_list]
169
+
170
+ def _add_speaker_and_signal(header, source, get_conversation=True):
171
+ """Add speaker and start/end signal on each round."""
172
+ BEGIN_SIGNAL = "###"
173
+ END_SIGNAL = "\n"
174
+ conversation = header
175
+ for sentence in source:
176
+ from_str = sentence["from"]
177
+ if from_str.lower() == "human":
178
+ from_str = video_conversation.roles[0]
179
+ elif from_str.lower() == "gpt":
180
+ from_str = video_conversation.roles[1]
181
+ else:
182
+ from_str = 'unknown'
183
+ sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
184
+ sentence["value"] + END_SIGNAL)
185
+ if get_conversation:
186
+ conversation += sentence["value"]
187
+ conversation += BEGIN_SIGNAL
188
+ return conversation
189
+
190
+ def _tokenize_fn(strings: Sequence[str],
191
+ tokenizer: transformers.PreTrainedTokenizer) -> Dict:
192
+ """Tokenize a list of strings."""
193
+ tokenized_list = [
194
+ tokenizer(
195
+ text,
196
+ return_tensors="pt",
197
+ padding="longest",
198
+ max_length=512,
199
+ truncation=True,
200
+ ) for text in strings
201
+ ]
202
+ input_ids = labels = [
203
+ tokenized.input_ids[0] for tokenized in tokenized_list
204
+ ]
205
+ input_ids_lens = labels_lens = [
206
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
207
+ for tokenized in tokenized_list
208
+ ]
209
+ return dict(
210
+ input_ids=input_ids,
211
+ labels=labels,
212
+ input_ids_lens=input_ids_lens,
213
+ labels_lens=labels_lens,
214
+ )
215
+
216
+ def preprocess(
217
+ sources: Sequence[str],
218
+ tokenizer: transformers.PreTrainedTokenizer,
219
+ ) -> Dict:
220
+ """
221
+ Given a list of sources, each is a conversation list. This transform:
222
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
223
+ 2. Concatenate conversations together;
224
+ 3. Tokenize the concatenated conversation;
225
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
226
+ """
227
+ # add end signal and concatenate together
228
+ conversations = []
229
+ for source in sources:
230
+ header = f"{video_conversation.system}\n\n"
231
+ conversation = _add_speaker_and_signal(header, source)
232
+ conversations.append(conversation)
233
+ # tokenize conversations
234
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
235
+ input_ids = conversations_tokenized["input_ids"]
236
+ targets = copy.deepcopy(input_ids)
237
+ for target, source in zip(targets, sources):
238
+ tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
239
+ tokenizer)["input_ids_lens"]
240
+ speakers = [sentence["from"] for sentence in source]
241
+ _mask_targets(target, tokenized_lens, speakers)
242
+
243
+ return dict(input_ids=input_ids, labels=targets)
244
+
245
+ def _mask_targets(target, tokenized_lens, speakers):
246
+ # cur_idx = 0
247
+ cur_idx = tokenized_lens[0]
248
+ tokenized_lens = tokenized_lens[1:]
249
+ target[:cur_idx] = IGNORE_INDEX
250
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
251
+ if speaker == "human":
252
+ target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
253
+ cur_idx += tokenized_len
video_llama/datasets/datasets/webvid_datasets.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ from video_llama.datasets.datasets.base_dataset import BaseDataset
10
+ from video_llama.datasets.datasets.caption_datasets import CaptionDataset
11
+ import pandas as pd
12
+ import decord
13
+ from decord import VideoReader
14
+ import random
15
+ import torch
16
+ from torch.utils.data.dataloader import default_collate
17
+ class WebvidDataset(BaseDataset):
18
+ def __init__(self, vis_processor, text_processor, vis_root, ann_root):
19
+ """
20
+ vis_root (string): Root directory of video (e.g. webvid_eval/video/)
21
+ ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
22
+ split (string): val or test
23
+ """
24
+ super().__init__(vis_processor=vis_processor, text_processor=text_processor)
25
+
26
+
27
+ # 读取一个路径下所有的
28
+
29
+ ts_df = []
30
+ for file_name in os.listdir(ann_root):
31
+ if file_name.endswith('.csv'):
32
+ df = pd.read_csv(os.path.join(ann_root, file_name))
33
+ ts_df.append(df)
34
+
35
+ merged_df = pd.concat(ts_df)
36
+ self.annotation = merged_df
37
+ self.vis_root = vis_root
38
+ self.resize_size = 224
39
+ self.num_frm = 8
40
+ self.frm_sampling_strategy = 'headtail'
41
+
42
+ def _get_video_path(self, sample):
43
+ rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
44
+ full_video_fp = os.path.join(self.vis_root, rel_video_fp)
45
+ return full_video_fp
46
+
47
+ def __getitem__(self, index):
48
+ num_retries = 10 # skip error videos
49
+ for _ in range(num_retries):
50
+ sample = self.annotation.iloc[index]
51
+ sample_dict = sample.to_dict()
52
+ video_id = sample_dict['videoid']
53
+
54
+ if 'name' in sample_dict.keys():
55
+ text = sample_dict['name'].strip()
56
+ else:
57
+ raise NotImplementedError("Un-supported text annotation format.")
58
+
59
+ # fetch video
60
+ video_path = self._get_video_path(sample_dict)
61
+ # if os.path.exists(video_path):
62
+ try:
63
+ video = self.vis_processor(video_path)
64
+ except:
65
+ print(f"Failed to load examples with video: {video_path}. "
66
+ f"Will randomly sample an example as a replacement.")
67
+ index = random.randint(0, len(self) - 1)
68
+ continue
69
+ caption = self.text_processor(text)
70
+
71
+ # print(video.size())
72
+ if video is None or caption is None \
73
+ or video.size()!=torch.Size([3,self.vis_processor.n_frms,224,224]):
74
+ print(f"Failed to load examples with video: {video_path}. "
75
+ f"Will randomly sample an example as a replacement.")
76
+ index = random.randint(0, len(self) - 1)
77
+ continue
78
+ else:
79
+ break
80
+ else:
81
+ raise RuntimeError(f"Failed to fetch video after {num_retries} retries.")
82
+ # "image_id" is kept to stay compatible with the COCO evaluation format
83
+ return {
84
+ "image": video,
85
+ "text_input": caption,
86
+ "type":'video',
87
+ }
88
+
89
+ def __len__(self):
90
+ return len(self.annotation)
91
+
92
+ # def collater(self, samples):
93
+ # new_result = {}
94
+ # new_result['image'] = default_collate( [sample["image"] for sample in samples])
95
+ # new_result['text_input'] = default_collate( [sample["text_input"] for sample in samples])
96
+ # return new_result
97
+
98
+ class WebvidDatasetEvalDataset(BaseDataset):
99
+ def __init__(self, vis_processor, text_processor, vis_root, ann_paths):
100
+ """
101
+ vis_root (string): Root directory of images (e.g. coco/images/)
102
+ ann_root (string): directory to store the annotation file
103
+ split (string): val or test
104
+ """
105
+ super().__init__(vis_processor, text_processor, vis_root, ann_paths)
106
+
107
+ def __getitem__(self, index):
108
+
109
+ ann = self.annotation[index]
110
+
111
+ vname = ann["video"]
112
+ video_path = os.path.join(self.vis_root, vname)
113
+
114
+ video = self.vis_processor(video_path)
115
+
116
+ return {
117
+ "video": video,
118
+ "image_id": ann["image_id"],
119
+ "instance_id": ann["instance_id"],
120
+ }
121
+
122
+
video_llama/models/Qformer.py ADDED
@@ -0,0 +1,1217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from salesforce@LAVIS. Below is the original copyright:
3
+ * Copyright (c) 2023, salesforce.com, inc.
4
+ * All rights reserved.
5
+ * SPDX-License-Identifier: BSD-3-Clause
6
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ * By Junnan Li
8
+ * Based on huggingface code base
9
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
10
+ """
11
+
12
+ import math
13
+ import os
14
+ import warnings
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Dict, Any
17
+
18
+ import torch
19
+ from torch import Tensor, device, dtype, nn
20
+ import torch.utils.checkpoint
21
+ from torch import nn
22
+ from torch.nn import CrossEntropyLoss
23
+ import torch.nn.functional as F
24
+
25
+ from transformers.activations import ACT2FN
26
+ from transformers.file_utils import (
27
+ ModelOutput,
28
+ )
29
+ from transformers.modeling_outputs import (
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPoolingAndCrossAttentions,
32
+ CausalLMOutputWithCrossAttentions,
33
+ MaskedLMOutput,
34
+ MultipleChoiceModelOutput,
35
+ NextSentencePredictorOutput,
36
+ QuestionAnsweringModelOutput,
37
+ SequenceClassifierOutput,
38
+ TokenClassifierOutput,
39
+ )
40
+ from transformers.modeling_utils import (
41
+ PreTrainedModel,
42
+ apply_chunking_to_forward,
43
+ find_pruneable_heads_and_indices,
44
+ prune_linear_layer,
45
+ )
46
+ from transformers.utils import logging
47
+ from transformers.models.bert.configuration_bert import BertConfig
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+
52
+ class BertEmbeddings(nn.Module):
53
+ """Construct the embeddings from word and position embeddings."""
54
+
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.word_embeddings = nn.Embedding(
58
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
59
+ )
60
+ self.position_embeddings = nn.Embedding(
61
+ config.max_position_embeddings, config.hidden_size
62
+ )
63
+
64
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
65
+ # any TensorFlow checkpoint file
66
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
67
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
68
+
69
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
70
+ self.register_buffer(
71
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
72
+ )
73
+ self.position_embedding_type = getattr(
74
+ config, "position_embedding_type", "absolute"
75
+ )
76
+
77
+ self.config = config
78
+
79
+ def forward(
80
+ self,
81
+ input_ids=None,
82
+ position_ids=None,
83
+ query_embeds=None,
84
+ past_key_values_length=0,
85
+ ):
86
+ if input_ids is not None:
87
+ seq_length = input_ids.size()[1]
88
+ else:
89
+ seq_length = 0
90
+
91
+ if position_ids is None:
92
+ position_ids = self.position_ids[
93
+ :, past_key_values_length : seq_length + past_key_values_length
94
+ ].clone()
95
+
96
+ if input_ids is not None:
97
+ embeddings = self.word_embeddings(input_ids)
98
+ if self.position_embedding_type == "absolute":
99
+ position_embeddings = self.position_embeddings(position_ids)
100
+ embeddings = embeddings + position_embeddings
101
+
102
+ if query_embeds is not None:
103
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
104
+ else:
105
+ embeddings = query_embeds
106
+
107
+ embeddings = self.LayerNorm(embeddings)
108
+ embeddings = self.dropout(embeddings)
109
+ return embeddings
110
+
111
+
112
+ class BertSelfAttention(nn.Module):
113
+ def __init__(self, config, is_cross_attention):
114
+ super().__init__()
115
+ self.config = config
116
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
117
+ config, "embedding_size"
118
+ ):
119
+ raise ValueError(
120
+ "The hidden size (%d) is not a multiple of the number of attention "
121
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
122
+ )
123
+
124
+ self.num_attention_heads = config.num_attention_heads
125
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
126
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
127
+
128
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
129
+ if is_cross_attention:
130
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
131
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
132
+ else:
133
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
134
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
135
+
136
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
137
+ self.position_embedding_type = getattr(
138
+ config, "position_embedding_type", "absolute"
139
+ )
140
+ if (
141
+ self.position_embedding_type == "relative_key"
142
+ or self.position_embedding_type == "relative_key_query"
143
+ ):
144
+ self.max_position_embeddings = config.max_position_embeddings
145
+ self.distance_embedding = nn.Embedding(
146
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
147
+ )
148
+ self.save_attention = False
149
+
150
+ def save_attn_gradients(self, attn_gradients):
151
+ self.attn_gradients = attn_gradients
152
+
153
+ def get_attn_gradients(self):
154
+ return self.attn_gradients
155
+
156
+ def save_attention_map(self, attention_map):
157
+ self.attention_map = attention_map
158
+
159
+ def get_attention_map(self):
160
+ return self.attention_map
161
+
162
+ def transpose_for_scores(self, x):
163
+ new_x_shape = x.size()[:-1] + (
164
+ self.num_attention_heads,
165
+ self.attention_head_size,
166
+ )
167
+ x = x.view(*new_x_shape)
168
+ return x.permute(0, 2, 1, 3)
169
+
170
+ def forward(
171
+ self,
172
+ hidden_states,
173
+ attention_mask=None,
174
+ head_mask=None,
175
+ encoder_hidden_states=None,
176
+ encoder_attention_mask=None,
177
+ past_key_value=None,
178
+ output_attentions=False,
179
+ ):
180
+
181
+ # If this is instantiated as a cross-attention module, the keys
182
+ # and values come from an encoder; the attention mask needs to be
183
+ # such that the encoder's padding tokens are not attended to.
184
+ is_cross_attention = encoder_hidden_states is not None
185
+
186
+ if is_cross_attention:
187
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
188
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
189
+ attention_mask = encoder_attention_mask
190
+ elif past_key_value is not None:
191
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
192
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
193
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
194
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
195
+ else:
196
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
197
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
198
+
199
+ mixed_query_layer = self.query(hidden_states)
200
+
201
+ query_layer = self.transpose_for_scores(mixed_query_layer)
202
+
203
+ past_key_value = (key_layer, value_layer)
204
+
205
+ # Take the dot product between "query" and "key" to get the raw attention scores.
206
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
207
+
208
+ if (
209
+ self.position_embedding_type == "relative_key"
210
+ or self.position_embedding_type == "relative_key_query"
211
+ ):
212
+ seq_length = hidden_states.size()[1]
213
+ position_ids_l = torch.arange(
214
+ seq_length, dtype=torch.long, device=hidden_states.device
215
+ ).view(-1, 1)
216
+ position_ids_r = torch.arange(
217
+ seq_length, dtype=torch.long, device=hidden_states.device
218
+ ).view(1, -1)
219
+ distance = position_ids_l - position_ids_r
220
+ positional_embedding = self.distance_embedding(
221
+ distance + self.max_position_embeddings - 1
222
+ )
223
+ positional_embedding = positional_embedding.to(
224
+ dtype=query_layer.dtype
225
+ ) # fp16 compatibility
226
+
227
+ if self.position_embedding_type == "relative_key":
228
+ relative_position_scores = torch.einsum(
229
+ "bhld,lrd->bhlr", query_layer, positional_embedding
230
+ )
231
+ attention_scores = attention_scores + relative_position_scores
232
+ elif self.position_embedding_type == "relative_key_query":
233
+ relative_position_scores_query = torch.einsum(
234
+ "bhld,lrd->bhlr", query_layer, positional_embedding
235
+ )
236
+ relative_position_scores_key = torch.einsum(
237
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
238
+ )
239
+ attention_scores = (
240
+ attention_scores
241
+ + relative_position_scores_query
242
+ + relative_position_scores_key
243
+ )
244
+
245
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
246
+ if attention_mask is not None:
247
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
248
+ attention_scores = attention_scores + attention_mask
249
+
250
+ # Normalize the attention scores to probabilities.
251
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
252
+
253
+ if is_cross_attention and self.save_attention:
254
+ self.save_attention_map(attention_probs)
255
+ attention_probs.register_hook(self.save_attn_gradients)
256
+
257
+ # This is actually dropping out entire tokens to attend to, which might
258
+ # seem a bit unusual, but is taken from the original Transformer paper.
259
+ attention_probs_dropped = self.dropout(attention_probs)
260
+
261
+ # Mask heads if we want to
262
+ if head_mask is not None:
263
+ attention_probs_dropped = attention_probs_dropped * head_mask
264
+
265
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
266
+
267
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
268
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
269
+ context_layer = context_layer.view(*new_context_layer_shape)
270
+
271
+ outputs = (
272
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
273
+ )
274
+
275
+ outputs = outputs + (past_key_value,)
276
+ return outputs
277
+
278
+
279
+ class BertSelfOutput(nn.Module):
280
+ def __init__(self, config):
281
+ super().__init__()
282
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
283
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
284
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
285
+
286
+ def forward(self, hidden_states, input_tensor):
287
+ hidden_states = self.dense(hidden_states)
288
+ hidden_states = self.dropout(hidden_states)
289
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
290
+ return hidden_states
291
+
292
+
293
+ class BertAttention(nn.Module):
294
+ def __init__(self, config, is_cross_attention=False):
295
+ super().__init__()
296
+ self.self = BertSelfAttention(config, is_cross_attention)
297
+ self.output = BertSelfOutput(config)
298
+ self.pruned_heads = set()
299
+
300
+ def prune_heads(self, heads):
301
+ if len(heads) == 0:
302
+ return
303
+ heads, index = find_pruneable_heads_and_indices(
304
+ heads,
305
+ self.self.num_attention_heads,
306
+ self.self.attention_head_size,
307
+ self.pruned_heads,
308
+ )
309
+
310
+ # Prune linear layers
311
+ self.self.query = prune_linear_layer(self.self.query, index)
312
+ self.self.key = prune_linear_layer(self.self.key, index)
313
+ self.self.value = prune_linear_layer(self.self.value, index)
314
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
315
+
316
+ # Update hyper params and store pruned heads
317
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
318
+ self.self.all_head_size = (
319
+ self.self.attention_head_size * self.self.num_attention_heads
320
+ )
321
+ self.pruned_heads = self.pruned_heads.union(heads)
322
+
323
+ def forward(
324
+ self,
325
+ hidden_states,
326
+ attention_mask=None,
327
+ head_mask=None,
328
+ encoder_hidden_states=None,
329
+ encoder_attention_mask=None,
330
+ past_key_value=None,
331
+ output_attentions=False,
332
+ ):
333
+ self_outputs = self.self(
334
+ hidden_states,
335
+ attention_mask,
336
+ head_mask,
337
+ encoder_hidden_states,
338
+ encoder_attention_mask,
339
+ past_key_value,
340
+ output_attentions,
341
+ )
342
+ attention_output = self.output(self_outputs[0], hidden_states)
343
+
344
+ outputs = (attention_output,) + self_outputs[
345
+ 1:
346
+ ] # add attentions if we output them
347
+ return outputs
348
+
349
+
350
+ class BertIntermediate(nn.Module):
351
+ def __init__(self, config):
352
+ super().__init__()
353
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
354
+ if isinstance(config.hidden_act, str):
355
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
356
+ else:
357
+ self.intermediate_act_fn = config.hidden_act
358
+
359
+ def forward(self, hidden_states):
360
+ hidden_states = self.dense(hidden_states)
361
+ hidden_states = self.intermediate_act_fn(hidden_states)
362
+ return hidden_states
363
+
364
+
365
+ class BertOutput(nn.Module):
366
+ def __init__(self, config):
367
+ super().__init__()
368
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
369
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
370
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
371
+
372
+ def forward(self, hidden_states, input_tensor):
373
+ hidden_states = self.dense(hidden_states)
374
+ hidden_states = self.dropout(hidden_states)
375
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
376
+ return hidden_states
377
+
378
+
379
+ class BertLayer(nn.Module):
380
+ def __init__(self, config, layer_num):
381
+ super().__init__()
382
+ self.config = config
383
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
384
+ self.seq_len_dim = 1
385
+ self.attention = BertAttention(config)
386
+ self.layer_num = layer_num
387
+ if (
388
+ self.config.add_cross_attention
389
+ and layer_num % self.config.cross_attention_freq == 0
390
+ ):
391
+ self.crossattention = BertAttention(
392
+ config, is_cross_attention=self.config.add_cross_attention
393
+ )
394
+ self.has_cross_attention = True
395
+ else:
396
+ self.has_cross_attention = False
397
+ self.intermediate = BertIntermediate(config)
398
+ self.output = BertOutput(config)
399
+
400
+ self.intermediate_query = BertIntermediate(config)
401
+ self.output_query = BertOutput(config)
402
+
403
+ def forward(
404
+ self,
405
+ hidden_states,
406
+ attention_mask=None,
407
+ head_mask=None,
408
+ encoder_hidden_states=None,
409
+ encoder_attention_mask=None,
410
+ past_key_value=None,
411
+ output_attentions=False,
412
+ query_length=0,
413
+ ):
414
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
415
+ self_attn_past_key_value = (
416
+ past_key_value[:2] if past_key_value is not None else None
417
+ )
418
+ self_attention_outputs = self.attention(
419
+ hidden_states,
420
+ attention_mask,
421
+ head_mask,
422
+ output_attentions=output_attentions,
423
+ past_key_value=self_attn_past_key_value,
424
+ )
425
+ attention_output = self_attention_outputs[0]
426
+ outputs = self_attention_outputs[1:-1]
427
+
428
+ present_key_value = self_attention_outputs[-1]
429
+
430
+ if query_length > 0:
431
+ query_attention_output = attention_output[:, :query_length, :]
432
+
433
+ if self.has_cross_attention:
434
+ assert (
435
+ encoder_hidden_states is not None
436
+ ), "encoder_hidden_states must be given for cross-attention layers"
437
+ cross_attention_outputs = self.crossattention(
438
+ query_attention_output,
439
+ attention_mask,
440
+ head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ output_attentions=output_attentions,
444
+ )
445
+ query_attention_output = cross_attention_outputs[0]
446
+ outputs = (
447
+ outputs + cross_attention_outputs[1:-1]
448
+ ) # add cross attentions if we output attention weights
449
+
450
+ layer_output = apply_chunking_to_forward(
451
+ self.feed_forward_chunk_query,
452
+ self.chunk_size_feed_forward,
453
+ self.seq_len_dim,
454
+ query_attention_output,
455
+ )
456
+ if attention_output.shape[1] > query_length:
457
+ layer_output_text = apply_chunking_to_forward(
458
+ self.feed_forward_chunk,
459
+ self.chunk_size_feed_forward,
460
+ self.seq_len_dim,
461
+ attention_output[:, query_length:, :],
462
+ )
463
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
464
+ else:
465
+ layer_output = apply_chunking_to_forward(
466
+ self.feed_forward_chunk,
467
+ self.chunk_size_feed_forward,
468
+ self.seq_len_dim,
469
+ attention_output,
470
+ )
471
+ outputs = (layer_output,) + outputs
472
+
473
+ outputs = outputs + (present_key_value,)
474
+
475
+ return outputs
476
+
477
+ def feed_forward_chunk(self, attention_output):
478
+ intermediate_output = self.intermediate(attention_output)
479
+ layer_output = self.output(intermediate_output, attention_output)
480
+ return layer_output
481
+
482
+ def feed_forward_chunk_query(self, attention_output):
483
+ intermediate_output = self.intermediate_query(attention_output)
484
+ layer_output = self.output_query(intermediate_output, attention_output)
485
+ return layer_output
486
+
487
+
488
+ class BertEncoder(nn.Module):
489
+ def __init__(self, config):
490
+ super().__init__()
491
+ self.config = config
492
+ self.layer = nn.ModuleList(
493
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
494
+ )
495
+
496
+ def forward(
497
+ self,
498
+ hidden_states,
499
+ attention_mask=None,
500
+ head_mask=None,
501
+ encoder_hidden_states=None,
502
+ encoder_attention_mask=None,
503
+ past_key_values=None,
504
+ use_cache=None,
505
+ output_attentions=False,
506
+ output_hidden_states=False,
507
+ return_dict=True,
508
+ query_length=0,
509
+ ):
510
+ all_hidden_states = () if output_hidden_states else None
511
+ all_self_attentions = () if output_attentions else None
512
+ all_cross_attentions = (
513
+ () if output_attentions and self.config.add_cross_attention else None
514
+ )
515
+
516
+ next_decoder_cache = () if use_cache else None
517
+
518
+ for i in range(self.config.num_hidden_layers):
519
+ layer_module = self.layer[i]
520
+ if output_hidden_states:
521
+ all_hidden_states = all_hidden_states + (hidden_states,)
522
+
523
+ layer_head_mask = head_mask[i] if head_mask is not None else None
524
+ past_key_value = past_key_values[i] if past_key_values is not None else None
525
+
526
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
527
+
528
+ if use_cache:
529
+ logger.warn(
530
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
531
+ )
532
+ use_cache = False
533
+
534
+ def create_custom_forward(module):
535
+ def custom_forward(*inputs):
536
+ return module(
537
+ *inputs, past_key_value, output_attentions, query_length
538
+ )
539
+
540
+ return custom_forward
541
+
542
+ layer_outputs = torch.utils.checkpoint.checkpoint(
543
+ create_custom_forward(layer_module),
544
+ hidden_states,
545
+ attention_mask,
546
+ layer_head_mask,
547
+ encoder_hidden_states,
548
+ encoder_attention_mask,
549
+ )
550
+ else:
551
+ layer_outputs = layer_module(
552
+ hidden_states,
553
+ attention_mask,
554
+ layer_head_mask,
555
+ encoder_hidden_states,
556
+ encoder_attention_mask,
557
+ past_key_value,
558
+ output_attentions,
559
+ query_length,
560
+ )
561
+
562
+ hidden_states = layer_outputs[0]
563
+ if use_cache:
564
+ next_decoder_cache += (layer_outputs[-1],)
565
+ if output_attentions:
566
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
567
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
568
+
569
+ if output_hidden_states:
570
+ all_hidden_states = all_hidden_states + (hidden_states,)
571
+
572
+ if not return_dict:
573
+ return tuple(
574
+ v
575
+ for v in [
576
+ hidden_states,
577
+ next_decoder_cache,
578
+ all_hidden_states,
579
+ all_self_attentions,
580
+ all_cross_attentions,
581
+ ]
582
+ if v is not None
583
+ )
584
+ return BaseModelOutputWithPastAndCrossAttentions(
585
+ last_hidden_state=hidden_states,
586
+ past_key_values=next_decoder_cache,
587
+ hidden_states=all_hidden_states,
588
+ attentions=all_self_attentions,
589
+ cross_attentions=all_cross_attentions,
590
+ )
591
+
592
+
593
+ class BertPooler(nn.Module):
594
+ def __init__(self, config):
595
+ super().__init__()
596
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
597
+ self.activation = nn.Tanh()
598
+
599
+ def forward(self, hidden_states):
600
+ # We "pool" the model by simply taking the hidden state corresponding
601
+ # to the first token.
602
+ first_token_tensor = hidden_states[:, 0]
603
+ pooled_output = self.dense(first_token_tensor)
604
+ pooled_output = self.activation(pooled_output)
605
+ return pooled_output
606
+
607
+
608
+ class BertPredictionHeadTransform(nn.Module):
609
+ def __init__(self, config):
610
+ super().__init__()
611
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
612
+ if isinstance(config.hidden_act, str):
613
+ self.transform_act_fn = ACT2FN[config.hidden_act]
614
+ else:
615
+ self.transform_act_fn = config.hidden_act
616
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
617
+
618
+ def forward(self, hidden_states):
619
+ hidden_states = self.dense(hidden_states)
620
+ hidden_states = self.transform_act_fn(hidden_states)
621
+ hidden_states = self.LayerNorm(hidden_states)
622
+ return hidden_states
623
+
624
+
625
+ class BertLMPredictionHead(nn.Module):
626
+ def __init__(self, config):
627
+ super().__init__()
628
+ self.transform = BertPredictionHeadTransform(config)
629
+
630
+ # The output weights are the same as the input embeddings, but there is
631
+ # an output-only bias for each token.
632
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
633
+
634
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
635
+
636
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
637
+ self.decoder.bias = self.bias
638
+
639
+ def forward(self, hidden_states):
640
+ hidden_states = self.transform(hidden_states)
641
+ hidden_states = self.decoder(hidden_states)
642
+ return hidden_states
643
+
644
+
645
+ class BertOnlyMLMHead(nn.Module):
646
+ def __init__(self, config):
647
+ super().__init__()
648
+ self.predictions = BertLMPredictionHead(config)
649
+
650
+ def forward(self, sequence_output):
651
+ prediction_scores = self.predictions(sequence_output)
652
+ return prediction_scores
653
+
654
+
655
+ class BertPreTrainedModel(PreTrainedModel):
656
+ """
657
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
658
+ models.
659
+ """
660
+
661
+ config_class = BertConfig
662
+ base_model_prefix = "bert"
663
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
664
+
665
+ def _init_weights(self, module):
666
+ """Initialize the weights"""
667
+ if isinstance(module, (nn.Linear, nn.Embedding)):
668
+ # Slightly different from the TF version which uses truncated_normal for initialization
669
+ # cf https://github.com/pytorch/pytorch/pull/5617
670
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
671
+ elif isinstance(module, nn.LayerNorm):
672
+ module.bias.data.zero_()
673
+ module.weight.data.fill_(1.0)
674
+ if isinstance(module, nn.Linear) and module.bias is not None:
675
+ module.bias.data.zero_()
676
+
677
+
678
+ class BertModel(BertPreTrainedModel):
679
+ """
680
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
681
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
682
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
683
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
684
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
685
+ input to the forward pass.
686
+ """
687
+
688
+ def __init__(self, config, add_pooling_layer=False):
689
+ super().__init__(config)
690
+ self.config = config
691
+
692
+ self.embeddings = BertEmbeddings(config)
693
+
694
+ self.encoder = BertEncoder(config)
695
+
696
+ self.pooler = BertPooler(config) if add_pooling_layer else None
697
+
698
+ self.init_weights()
699
+
700
+ def get_input_embeddings(self):
701
+ return self.embeddings.word_embeddings
702
+
703
+ def set_input_embeddings(self, value):
704
+ self.embeddings.word_embeddings = value
705
+
706
+ def _prune_heads(self, heads_to_prune):
707
+ """
708
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
709
+ class PreTrainedModel
710
+ """
711
+ for layer, heads in heads_to_prune.items():
712
+ self.encoder.layer[layer].attention.prune_heads(heads)
713
+
714
+ def get_extended_attention_mask(
715
+ self,
716
+ attention_mask: Tensor,
717
+ input_shape: Tuple[int],
718
+ device: device,
719
+ is_decoder: bool,
720
+ has_query: bool = False,
721
+ ) -> Tensor:
722
+ """
723
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
724
+
725
+ Arguments:
726
+ attention_mask (:obj:`torch.Tensor`):
727
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
728
+ input_shape (:obj:`Tuple[int]`):
729
+ The shape of the input to the model.
730
+ device: (:obj:`torch.device`):
731
+ The device of the input to the model.
732
+
733
+ Returns:
734
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
735
+ """
736
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
737
+ # ourselves in which case we just need to make it broadcastable to all heads.
738
+ if attention_mask.dim() == 3:
739
+ extended_attention_mask = attention_mask[:, None, :, :]
740
+ elif attention_mask.dim() == 2:
741
+ # Provided a padding mask of dimensions [batch_size, seq_length]
742
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
743
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
744
+ if is_decoder:
745
+ batch_size, seq_length = input_shape
746
+
747
+ seq_ids = torch.arange(seq_length, device=device)
748
+ causal_mask = (
749
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
750
+ <= seq_ids[None, :, None]
751
+ )
752
+
753
+ # add a prefix ones mask to the causal mask
754
+ # causal and attention masks must have same type with pytorch version < 1.3
755
+ causal_mask = causal_mask.to(attention_mask.dtype)
756
+
757
+ if causal_mask.shape[1] < attention_mask.shape[1]:
758
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
759
+ if has_query: # UniLM style attention mask
760
+ causal_mask = torch.cat(
761
+ [
762
+ torch.zeros(
763
+ (batch_size, prefix_seq_len, seq_length),
764
+ device=device,
765
+ dtype=causal_mask.dtype,
766
+ ),
767
+ causal_mask,
768
+ ],
769
+ axis=1,
770
+ )
771
+ causal_mask = torch.cat(
772
+ [
773
+ torch.ones(
774
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
775
+ device=device,
776
+ dtype=causal_mask.dtype,
777
+ ),
778
+ causal_mask,
779
+ ],
780
+ axis=-1,
781
+ )
782
+ extended_attention_mask = (
783
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
784
+ )
785
+ else:
786
+ extended_attention_mask = attention_mask[:, None, None, :]
787
+ else:
788
+ raise ValueError(
789
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
790
+ input_shape, attention_mask.shape
791
+ )
792
+ )
793
+
794
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
795
+ # masked positions, this operation will create a tensor which is 0.0 for
796
+ # positions we want to attend and -10000.0 for masked positions.
797
+ # Since we are adding it to the raw scores before the softmax, this is
798
+ # effectively the same as removing these entirely.
799
+ extended_attention_mask = extended_attention_mask.to(
800
+ dtype=self.dtype
801
+ ) # fp16 compatibility
802
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
803
+ return extended_attention_mask
804
+
805
+ def forward(
806
+ self,
807
+ input_ids=None,
808
+ attention_mask=None,
809
+ position_ids=None,
810
+ head_mask=None,
811
+ query_embeds=None,
812
+ encoder_hidden_states=None,
813
+ encoder_attention_mask=None,
814
+ past_key_values=None,
815
+ use_cache=None,
816
+ output_attentions=None,
817
+ output_hidden_states=None,
818
+ return_dict=None,
819
+ is_decoder=False,
820
+ ):
821
+ r"""
822
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
823
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
824
+ the model is configured as a decoder.
825
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
826
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
827
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
828
+ - 1 for tokens that are **not masked**,
829
+ - 0 for tokens that are **masked**.
830
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
831
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
832
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
833
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
834
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
835
+ use_cache (:obj:`bool`, `optional`):
836
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
837
+ decoding (see :obj:`past_key_values`).
838
+ """
839
+ output_attentions = (
840
+ output_attentions
841
+ if output_attentions is not None
842
+ else self.config.output_attentions
843
+ )
844
+ output_hidden_states = (
845
+ output_hidden_states
846
+ if output_hidden_states is not None
847
+ else self.config.output_hidden_states
848
+ )
849
+ return_dict = (
850
+ return_dict if return_dict is not None else self.config.use_return_dict
851
+ )
852
+
853
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
854
+
855
+ if input_ids is None:
856
+ assert (
857
+ query_embeds is not None
858
+ ), "You have to specify query_embeds when input_ids is None"
859
+
860
+ # past_key_values_length
861
+ past_key_values_length = (
862
+ past_key_values[0][0].shape[2] - self.config.query_length
863
+ if past_key_values is not None
864
+ else 0
865
+ )
866
+
867
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
868
+
869
+ embedding_output = self.embeddings(
870
+ input_ids=input_ids,
871
+ position_ids=position_ids,
872
+ query_embeds=query_embeds,
873
+ past_key_values_length=past_key_values_length,
874
+ )
875
+
876
+ input_shape = embedding_output.size()[:-1]
877
+ batch_size, seq_length = input_shape
878
+ device = embedding_output.device
879
+
880
+ if attention_mask is None:
881
+ attention_mask = torch.ones(
882
+ ((batch_size, seq_length + past_key_values_length)), device=device
883
+ )
884
+
885
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
886
+ # ourselves in which case we just need to make it broadcastable to all heads.
887
+ if is_decoder:
888
+ extended_attention_mask = self.get_extended_attention_mask(
889
+ attention_mask,
890
+ input_ids.shape,
891
+ device,
892
+ is_decoder,
893
+ has_query=(query_embeds is not None),
894
+ )
895
+ else:
896
+ extended_attention_mask = self.get_extended_attention_mask(
897
+ attention_mask, input_shape, device, is_decoder
898
+ )
899
+
900
+ # If a 2D or 3D attention mask is provided for the cross-attention
901
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
902
+ if encoder_hidden_states is not None:
903
+ if type(encoder_hidden_states) == list:
904
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
905
+ 0
906
+ ].size()
907
+ else:
908
+ (
909
+ encoder_batch_size,
910
+ encoder_sequence_length,
911
+ _,
912
+ ) = encoder_hidden_states.size()
913
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
914
+
915
+ if type(encoder_attention_mask) == list:
916
+ encoder_extended_attention_mask = [
917
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
918
+ ]
919
+ elif encoder_attention_mask is None:
920
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
921
+ encoder_extended_attention_mask = self.invert_attention_mask(
922
+ encoder_attention_mask
923
+ )
924
+ else:
925
+ encoder_extended_attention_mask = self.invert_attention_mask(
926
+ encoder_attention_mask
927
+ )
928
+ else:
929
+ encoder_extended_attention_mask = None
930
+
931
+ # Prepare head mask if needed
932
+ # 1.0 in head_mask indicate we keep the head
933
+ # attention_probs has shape bsz x n_heads x N x N
934
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
935
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
936
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
937
+
938
+ encoder_outputs = self.encoder(
939
+ embedding_output,
940
+ attention_mask=extended_attention_mask,
941
+ head_mask=head_mask,
942
+ encoder_hidden_states=encoder_hidden_states,
943
+ encoder_attention_mask=encoder_extended_attention_mask,
944
+ past_key_values=past_key_values,
945
+ use_cache=use_cache,
946
+ output_attentions=output_attentions,
947
+ output_hidden_states=output_hidden_states,
948
+ return_dict=return_dict,
949
+ query_length=query_length,
950
+ )
951
+ sequence_output = encoder_outputs[0]
952
+ pooled_output = (
953
+ self.pooler(sequence_output) if self.pooler is not None else None
954
+ )
955
+
956
+ if not return_dict:
957
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
958
+
959
+ return BaseModelOutputWithPoolingAndCrossAttentions(
960
+ last_hidden_state=sequence_output,
961
+ pooler_output=pooled_output,
962
+ past_key_values=encoder_outputs.past_key_values,
963
+ hidden_states=encoder_outputs.hidden_states,
964
+ attentions=encoder_outputs.attentions,
965
+ cross_attentions=encoder_outputs.cross_attentions,
966
+ )
967
+
968
+
969
+ class BertLMHeadModel(BertPreTrainedModel):
970
+
971
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
972
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
973
+
974
+ def __init__(self, config):
975
+ super().__init__(config)
976
+
977
+ self.bert = BertModel(config, add_pooling_layer=False)
978
+ self.cls = BertOnlyMLMHead(config)
979
+
980
+ self.init_weights()
981
+
982
+ def get_output_embeddings(self):
983
+ return self.cls.predictions.decoder
984
+
985
+ def set_output_embeddings(self, new_embeddings):
986
+ self.cls.predictions.decoder = new_embeddings
987
+
988
+ def forward(
989
+ self,
990
+ input_ids=None,
991
+ attention_mask=None,
992
+ position_ids=None,
993
+ head_mask=None,
994
+ query_embeds=None,
995
+ encoder_hidden_states=None,
996
+ encoder_attention_mask=None,
997
+ labels=None,
998
+ past_key_values=None,
999
+ use_cache=True,
1000
+ output_attentions=None,
1001
+ output_hidden_states=None,
1002
+ return_dict=None,
1003
+ return_logits=False,
1004
+ is_decoder=True,
1005
+ reduction="mean",
1006
+ ):
1007
+ r"""
1008
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1009
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1010
+ the model is configured as a decoder.
1011
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1012
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1013
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1014
+ - 1 for tokens that are **not masked**,
1015
+ - 0 for tokens that are **masked**.
1016
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1017
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1018
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1019
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1020
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1021
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1022
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1023
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1024
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1025
+ use_cache (:obj:`bool`, `optional`):
1026
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1027
+ decoding (see :obj:`past_key_values`).
1028
+ Returns:
1029
+ Example::
1030
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1031
+ >>> import torch
1032
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1033
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1034
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1035
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1036
+ >>> outputs = model(**inputs)
1037
+ >>> prediction_logits = outputs.logits
1038
+ """
1039
+ return_dict = (
1040
+ return_dict if return_dict is not None else self.config.use_return_dict
1041
+ )
1042
+ if labels is not None:
1043
+ use_cache = False
1044
+ if past_key_values is not None:
1045
+ query_embeds = None
1046
+
1047
+ outputs = self.bert(
1048
+ input_ids,
1049
+ attention_mask=attention_mask,
1050
+ position_ids=position_ids,
1051
+ head_mask=head_mask,
1052
+ query_embeds=query_embeds,
1053
+ encoder_hidden_states=encoder_hidden_states,
1054
+ encoder_attention_mask=encoder_attention_mask,
1055
+ past_key_values=past_key_values,
1056
+ use_cache=use_cache,
1057
+ output_attentions=output_attentions,
1058
+ output_hidden_states=output_hidden_states,
1059
+ return_dict=return_dict,
1060
+ is_decoder=is_decoder,
1061
+ )
1062
+
1063
+ sequence_output = outputs[0]
1064
+ if query_embeds is not None:
1065
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1066
+
1067
+ prediction_scores = self.cls(sequence_output)
1068
+
1069
+ if return_logits:
1070
+ return prediction_scores[:, :-1, :].contiguous()
1071
+
1072
+ lm_loss = None
1073
+ if labels is not None:
1074
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1075
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1076
+ labels = labels[:, 1:].contiguous()
1077
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1078
+ lm_loss = loss_fct(
1079
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1080
+ labels.view(-1),
1081
+ )
1082
+ if reduction == "none":
1083
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1084
+
1085
+ if not return_dict:
1086
+ output = (prediction_scores,) + outputs[2:]
1087
+ return ((lm_loss,) + output) if lm_loss is not None else output
1088
+
1089
+ return CausalLMOutputWithCrossAttentions(
1090
+ loss=lm_loss,
1091
+ logits=prediction_scores,
1092
+ past_key_values=outputs.past_key_values,
1093
+ hidden_states=outputs.hidden_states,
1094
+ attentions=outputs.attentions,
1095
+ cross_attentions=outputs.cross_attentions,
1096
+ )
1097
+
1098
+ def prepare_inputs_for_generation(
1099
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1100
+ ):
1101
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1102
+ if attention_mask is None:
1103
+ attention_mask = input_ids.new_ones(input_ids.shape)
1104
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1105
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1106
+
1107
+ # cut decoder_input_ids if past is used
1108
+ if past is not None:
1109
+ input_ids = input_ids[:, -1:]
1110
+
1111
+ return {
1112
+ "input_ids": input_ids,
1113
+ "query_embeds": query_embeds,
1114
+ "attention_mask": attention_mask,
1115
+ "past_key_values": past,
1116
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1117
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1118
+ "is_decoder": True,
1119
+ }
1120
+
1121
+ def _reorder_cache(self, past, beam_idx):
1122
+ reordered_past = ()
1123
+ for layer_past in past:
1124
+ reordered_past += (
1125
+ tuple(
1126
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1127
+ ),
1128
+ )
1129
+ return reordered_past
1130
+
1131
+
1132
+ class BertForMaskedLM(BertPreTrainedModel):
1133
+
1134
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1135
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1136
+
1137
+ def __init__(self, config):
1138
+ super().__init__(config)
1139
+
1140
+ self.bert = BertModel(config, add_pooling_layer=False)
1141
+ self.cls = BertOnlyMLMHead(config)
1142
+
1143
+ self.init_weights()
1144
+
1145
+ def get_output_embeddings(self):
1146
+ return self.cls.predictions.decoder
1147
+
1148
+ def set_output_embeddings(self, new_embeddings):
1149
+ self.cls.predictions.decoder = new_embeddings
1150
+
1151
+ def forward(
1152
+ self,
1153
+ input_ids=None,
1154
+ attention_mask=None,
1155
+ position_ids=None,
1156
+ head_mask=None,
1157
+ query_embeds=None,
1158
+ encoder_hidden_states=None,
1159
+ encoder_attention_mask=None,
1160
+ labels=None,
1161
+ output_attentions=None,
1162
+ output_hidden_states=None,
1163
+ return_dict=None,
1164
+ return_logits=False,
1165
+ is_decoder=False,
1166
+ ):
1167
+ r"""
1168
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1169
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1170
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1171
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1172
+ """
1173
+
1174
+ return_dict = (
1175
+ return_dict if return_dict is not None else self.config.use_return_dict
1176
+ )
1177
+
1178
+ outputs = self.bert(
1179
+ input_ids,
1180
+ attention_mask=attention_mask,
1181
+ position_ids=position_ids,
1182
+ head_mask=head_mask,
1183
+ query_embeds=query_embeds,
1184
+ encoder_hidden_states=encoder_hidden_states,
1185
+ encoder_attention_mask=encoder_attention_mask,
1186
+ output_attentions=output_attentions,
1187
+ output_hidden_states=output_hidden_states,
1188
+ return_dict=return_dict,
1189
+ is_decoder=is_decoder,
1190
+ )
1191
+
1192
+ if query_embeds is not None:
1193
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1194
+ prediction_scores = self.cls(sequence_output)
1195
+
1196
+ if return_logits:
1197
+ return prediction_scores
1198
+
1199
+ masked_lm_loss = None
1200
+ if labels is not None:
1201
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1202
+ masked_lm_loss = loss_fct(
1203
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1204
+ )
1205
+
1206
+ if not return_dict:
1207
+ output = (prediction_scores,) + outputs[2:]
1208
+ return (
1209
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1210
+ )
1211
+
1212
+ return MaskedLMOutput(
1213
+ loss=masked_lm_loss,
1214
+ logits=prediction_scores,
1215
+ hidden_states=outputs.hidden_states,
1216
+ attentions=outputs.attentions,
1217
+ )
video_llama/models/__init__.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from salesforce@LAVIS Vision-CAIR@MiniGPT-4. Below is the original copyright:
3
+ Copyright (c) 2022, salesforce.com, inc.
4
+ All rights reserved.
5
+ SPDX-License-Identifier: BSD-3-Clause
6
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ """
8
+
9
+ import logging
10
+ import torch
11
+ from omegaconf import OmegaConf
12
+
13
+ from video_llama.common.registry import registry
14
+ from video_llama.models.base_model import BaseModel
15
+ from video_llama.models.blip2 import Blip2Base
16
+ from video_llama.models.video_llama import VideoLLAMA
17
+ from video_llama.processors.base_processor import BaseProcessor
18
+
19
+
20
+ __all__ = [
21
+ "load_model",
22
+ "BaseModel",
23
+ "Blip2Base",
24
+ "VideoLLAMA"
25
+ ]
26
+
27
+
28
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
29
+ """
30
+ Load supported models.
31
+
32
+ To list all available models and types in registry:
33
+ >>> from video_llama.models import model_zoo
34
+ >>> print(model_zoo)
35
+
36
+ Args:
37
+ name (str): name of the model.
38
+ model_type (str): type of the model.
39
+ is_eval (bool): whether the model is in eval mode. Default: False.
40
+ device (str): device to use. Default: "cpu".
41
+ checkpoint (str): path or to checkpoint. Default: None.
42
+ Note that expecting the checkpoint to have the same keys in state_dict as the model.
43
+
44
+ Returns:
45
+ model (torch.nn.Module): model.
46
+ """
47
+
48
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
49
+
50
+ if checkpoint is not None:
51
+ model.load_checkpoint(checkpoint)
52
+
53
+ if is_eval:
54
+ model.eval()
55
+
56
+ if device == "cpu":
57
+ model = model.float()
58
+
59
+ return model.to(device)
60
+
61
+
62
+ def load_preprocess(config):
63
+ """
64
+ Load preprocessor configs and construct preprocessors.
65
+
66
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
67
+
68
+ Args:
69
+ config (dict): preprocessor configs.
70
+
71
+ Returns:
72
+ vis_processors (dict): preprocessors for visual inputs.
73
+ txt_processors (dict): preprocessors for text inputs.
74
+
75
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
76
+ """
77
+
78
+ def _build_proc_from_cfg(cfg):
79
+ return (
80
+ registry.get_processor_class(cfg.name).from_config(cfg)
81
+ if cfg is not None
82
+ else BaseProcessor()
83
+ )
84
+
85
+ vis_processors = dict()
86
+ txt_processors = dict()
87
+
88
+ vis_proc_cfg = config.get("vis_processor")
89
+ txt_proc_cfg = config.get("text_processor")
90
+
91
+ if vis_proc_cfg is not None:
92
+ vis_train_cfg = vis_proc_cfg.get("train")
93
+ vis_eval_cfg = vis_proc_cfg.get("eval")
94
+ else:
95
+ vis_train_cfg = None
96
+ vis_eval_cfg = None
97
+
98
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
99
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
100
+
101
+ if txt_proc_cfg is not None:
102
+ txt_train_cfg = txt_proc_cfg.get("train")
103
+ txt_eval_cfg = txt_proc_cfg.get("eval")
104
+ else:
105
+ txt_train_cfg = None
106
+ txt_eval_cfg = None
107
+
108
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
109
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
110
+
111
+ return vis_processors, txt_processors
112
+
113
+
114
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
115
+ """
116
+ Load model and its related preprocessors.
117
+
118
+ List all available models and types in registry:
119
+ >>> from video_llama.models import model_zoo
120
+ >>> print(model_zoo)
121
+
122
+ Args:
123
+ name (str): name of the model.
124
+ model_type (str): type of the model.
125
+ is_eval (bool): whether the model is in eval mode. Default: False.
126
+ device (str): device to use. Default: "cpu".
127
+
128
+ Returns:
129
+ model (torch.nn.Module): model.
130
+ vis_processors (dict): preprocessors for visual inputs.
131
+ txt_processors (dict): preprocessors for text inputs.
132
+ """
133
+ model_cls = registry.get_model_class(name)
134
+
135
+ # load model
136
+ model = model_cls.from_pretrained(model_type=model_type)
137
+
138
+ if is_eval:
139
+ model.eval()
140
+
141
+ # load preprocess
142
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
143
+ if cfg is not None:
144
+ preprocess_cfg = cfg.preprocess
145
+
146
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
147
+ else:
148
+ vis_processors, txt_processors = None, None
149
+ logging.info(
150
+ f"""No default preprocess for model {name} ({model_type}).
151
+ This can happen if the model is not finetuned on downstream datasets,
152
+ or it is not intended for direct use without finetuning.
153
+ """
154
+ )
155
+
156
+ if device == "cpu" or device == torch.device("cpu"):
157
+ model = model.float()
158
+
159
+ return model.to(device), vis_processors, txt_processors
160
+
161
+
162
+ class ModelZoo:
163
+ """
164
+ A utility class to create string representation of available model architectures and types.
165
+
166
+ >>> from video_llama.models import model_zoo
167
+ >>> # list all available models
168
+ >>> print(model_zoo)
169
+ >>> # show total number of models
170
+ >>> print(len(model_zoo))
171
+ """
172
+
173
+ def __init__(self) -> None:
174
+ self.model_zoo = {
175
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
176
+ for k, v in registry.mapping["model_name_mapping"].items()
177
+ }
178
+
179
+ def __str__(self) -> str:
180
+ return (
181
+ "=" * 50
182
+ + "\n"
183
+ + f"{'Architectures':<30} {'Types'}\n"
184
+ + "=" * 50
185
+ + "\n"
186
+ + "\n".join(
187
+ [
188
+ f"{name:<30} {', '.join(types)}"
189
+ for name, types in self.model_zoo.items()
190
+ ]
191
+ )
192
+ )
193
+
194
+ def __iter__(self):
195
+ return iter(self.model_zoo.items())
196
+
197
+ def __len__(self):
198
+ return sum([len(v) for v in self.model_zoo.values()])
199
+
200
+
201
+ model_zoo = ModelZoo()