Reuben Tan commited on
Commit
b2afdba
1 Parent(s): 205dfac

initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +249 -0
  2. ckpt/VL_LLaMA_2_7B_Finetuned.pth +3 -0
  3. ckpt/finetuned_model.pth +3 -0
  4. demo_video.py +249 -0
  5. eval_configs/conversation_demo.yaml +78 -0
  6. global_local/__init__.py +31 -0
  7. global_local/__pycache__/__init__.cpython-39.pyc +0 -0
  8. global_local/common/__init__.py +0 -0
  9. global_local/common/__pycache__/__init__.cpython-39.pyc +0 -0
  10. global_local/common/__pycache__/config.cpython-39.pyc +0 -0
  11. global_local/common/__pycache__/dist_utils.cpython-39.pyc +0 -0
  12. global_local/common/__pycache__/logger.cpython-39.pyc +0 -0
  13. global_local/common/__pycache__/optims.cpython-39.pyc +0 -0
  14. global_local/common/__pycache__/registry.cpython-39.pyc +0 -0
  15. global_local/common/__pycache__/utils.cpython-39.pyc +0 -0
  16. global_local/common/config.py +468 -0
  17. global_local/common/dist_utils.py +156 -0
  18. global_local/common/gradcam.py +24 -0
  19. global_local/common/logger.py +195 -0
  20. global_local/common/optims.py +134 -0
  21. global_local/common/registry.py +329 -0
  22. global_local/common/utils.py +424 -0
  23. global_local/configs/datasets/cc_sbu/align.yaml +5 -0
  24. global_local/configs/datasets/cc_sbu/defaults.yaml +5 -0
  25. global_local/configs/datasets/instruct/llava_instruct.yaml +6 -0
  26. global_local/configs/datasets/instruct/webvid_instruct.yaml +6 -0
  27. global_local/configs/datasets/laion/defaults.yaml +5 -0
  28. global_local/configs/datasets/webvid/defaults.yaml +6 -0
  29. global_local/configs/default.yaml +5 -0
  30. global_local/configs/models/minigpt4.yaml +33 -0
  31. global_local/configs/models/video_llama.yaml +36 -0
  32. global_local/conversation/__init__.py +0 -0
  33. global_local/conversation/__pycache__/__init__.cpython-39.pyc +0 -0
  34. global_local/conversation/__pycache__/conversation_video.cpython-39.pyc +0 -0
  35. global_local/conversation/conversation_video.py +404 -0
  36. global_local/datasets/__init__.py +0 -0
  37. global_local/datasets/__pycache__/__init__.cpython-39.pyc +0 -0
  38. global_local/datasets/__pycache__/data_utils.cpython-39.pyc +0 -0
  39. global_local/datasets/builders/__init__.py +77 -0
  40. global_local/datasets/builders/__pycache__/__init__.cpython-39.pyc +0 -0
  41. global_local/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc +0 -0
  42. global_local/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc +0 -0
  43. global_local/datasets/builders/__pycache__/instruct_builder.cpython-39.pyc +0 -0
  44. global_local/datasets/builders/__pycache__/video_caption_builder.cpython-39.pyc +0 -0
  45. global_local/datasets/builders/base_dataset_builder.py +236 -0
  46. global_local/datasets/builders/image_text_pair_builder.py +106 -0
  47. global_local/datasets/builders/instruct_builder.py +78 -0
  48. global_local/datasets/builders/video_caption_builder.py +34 -0
  49. global_local/datasets/data_utils.py +196 -0
  50. global_local/datasets/datasets/__init__.py +0 -0
app.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py
3
+ """
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import random
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.backends.cudnn as cudnn
12
+ import gradio as gr
13
+
14
+ from global_local.common.config import Config
15
+ from global_local.common.dist_utils import get_rank
16
+ from global_local.common.registry import registry
17
+ from global_local.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle,conv_llava_llama_2
18
+ import decord
19
+ decord.bridge.set_bridge('torch')
20
+
21
+ #%%
22
+ # imports modules for registration
23
+ from global_local.datasets.builders import *
24
+ from global_local.models import *
25
+ from global_local.processors import *
26
+ from global_local.runners import *
27
+ from global_local.tasks import *
28
+
29
+ #%%
30
+ def parse_args():
31
+ parser = argparse.ArgumentParser(description="Demo")
32
+ #parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
33
+ parser.add_argument("--cfg-path", type=str, default='./eval_configs/conversation_demo.yaml', help="path to configuration file.")
34
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
35
+ parser.add_argument("--model_type", type=str, default='llama_v2', help="specify LLM")
36
+ parser.add_argument('--pretrained_weight_path', type=str, default="./ckpt/finetuned_model.pth", metavar='PWP',
37
+ help='path to pretrained weight path')
38
+ parser.add_argument('--num_frames_per_clip', type=int, default=16, metavar='NPPC',
39
+ help='specify how frames to use per clip')
40
+ parser.add_argument('--num_segments', type=int, default=4, metavar='NS',
41
+ help='specify number of video segments')
42
+ parser.add_argument('--hierarchical_agg_function', type=str, default="without-top-final-global-prompts-region-segment-full-dis-spatiotemporal-prompts-attn-early-attn-linear-learned", metavar='HAF',
43
+ help='specify function to merge global and clip visual representations')
44
+
45
+ parser.add_argument(
46
+ "--options",
47
+ nargs="+",
48
+ help="override some settings in the used config, the key-value pair "
49
+ "in xxx=yyy format will be merged into config file (deprecate), "
50
+ "change to --cfg-options instead.",
51
+ )
52
+ args = parser.parse_args()
53
+ return args
54
+
55
+
56
+ def setup_seeds(config):
57
+ seed = config.run_cfg.seed + get_rank()
58
+
59
+ random.seed(seed)
60
+ np.random.seed(seed)
61
+ torch.manual_seed(seed)
62
+
63
+ cudnn.benchmark = False
64
+ cudnn.deterministic = True
65
+
66
+
67
+ # ========================================
68
+ # Model Initialization
69
+ # ========================================
70
+
71
+ print('Initializing Chat')
72
+ args = parse_args()
73
+ cfg = Config(args)
74
+
75
+ model_config = cfg.model_cfg
76
+ model_config.device_8bit = args.gpu_id
77
+ model_cls = registry.get_model_class(model_config.arch)
78
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
79
+
80
+ model.num_frames_per_clip = args.num_frames_per_clip
81
+ model.num_segments = args.num_segments
82
+ model.hierarchical_agg_function = args.hierarchical_agg_function
83
+ model.global_region_embed_weight = None
84
+
85
+ model.initialize_visual_agg_function()
86
+
87
+ best_checkpoint = torch.load(args.pretrained_weight_path, map_location='cpu')['model_state_dict']
88
+ pretrained_dict = {}
89
+ for k, v in best_checkpoint.items():
90
+ pretrained_dict[k.replace('module.', '')] = v
91
+
92
+ model_dict = model.state_dict()
93
+ model_dict.update(pretrained_dict)
94
+ model.load_state_dict(model_dict)
95
+ model.cuda().eval()
96
+
97
+ #vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
98
+ vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
99
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
100
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
101
+ print('Initialization Finished')
102
+
103
+ # ========================================
104
+ # Gradio Setting
105
+ # ========================================
106
+
107
+ def gradio_reset(chat_state, img_list):
108
+ if chat_state is not None:
109
+ chat_state.messages = []
110
+ if img_list is not None:
111
+ img_list = []
112
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
113
+
114
+ def upload_imgorvideo(gr_video, gr_img, text_input, chat_state,chatbot):
115
+ if args.model_type == 'vicuna':
116
+ chat_state = default_conversation.copy()
117
+ else:
118
+ chat_state = conv_llava_llama_2.copy()
119
+ if gr_img is None and gr_video is None:
120
+ return None, None, None, gr.update(interactive=True), chat_state, None
121
+ elif gr_img is not None and gr_video is None:
122
+ print(gr_img)
123
+ chatbot = chatbot + [((gr_img,), None)]
124
+ chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
125
+ img_list = []
126
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
127
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list,chatbot
128
+ elif gr_video is not None and gr_img is None:
129
+ print(gr_video)
130
+ chatbot = chatbot + [((gr_video,), None)]
131
+ chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
132
+ img_list = []
133
+ llm_message = chat.upload_video_without_audio(gr_video, chat_state, img_list)
134
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list,chatbot
135
+ else:
136
+ # img_list = []
137
+ return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None,chatbot
138
+
139
+ def gradio_ask(user_message, chatbot, chat_state):
140
+ if len(user_message) == 0:
141
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
142
+ chat.ask(user_message, chat_state)
143
+ chatbot = chatbot + [[user_message, None]]
144
+ return '', chatbot, chat_state
145
+
146
+
147
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
148
+ llm_message = chat.answer(conv=chat_state,
149
+ img_list=img_list,
150
+ num_beams=num_beams,
151
+ temperature=temperature,
152
+ max_new_tokens=300,
153
+ max_length=2000)[0]
154
+ chatbot[-1][1] = llm_message
155
+ print(chat_state.get_prompt())
156
+ print(chat_state)
157
+ return chatbot, chat_state, img_list
158
+
159
+ title = """
160
+ <h1 align="center">Global-Local QFormer for Long Video Understanding with LLMs</h1>
161
+
162
+ <h5 align="center"> Introduction: We introduce a Global-Local QFormer video model that is connected with a Large Language Model to understand \
163
+ and answer questions about long videos. </h5>
164
+
165
+ <div style='display:flex; gap: 0.25rem; '>
166
+ <a href='https://huggingface.co/spaces/rxtan/rxtan/Global-Local-QFormer-Video-LLM'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
167
+ <a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
168
+ </div>
169
+
170
+
171
+ Thank you for using the Global-Local QFormer Demo Page! If you have any questions or feedback, feel free to contact us.
172
+ Current online demo uses the 7B version of Llama-2 due to resource limitations.
173
+
174
+
175
+ """
176
+
177
+ Note_markdown = ("""
178
+ ### We note that our Global-Local QFormer model may be limited at understanding videos from rare domains. Due to the pretraining data, the \
179
+ model may be susceptible to hallucinations
180
+ We would like to acknowledge the Video-LLama repository which we copied the demo layout from.
181
+
182
+ **Boston University**
183
+ """)
184
+
185
+ cite_markdown = ("""
186
+ """)
187
+
188
+ #case_note_upload = ("""
189
+ ### We provide some examples at the bottom of the page. Simply click on them to try them out directly.
190
+ #""")
191
+
192
+ #TODO show examples below
193
+
194
+ with gr.Blocks() as demo:
195
+ gr.Markdown(title)
196
+
197
+ with gr.Row():
198
+ with gr.Column(scale=0.5):
199
+ video = gr.Video()
200
+ #image = gr.Image(type="filepath")
201
+ #gr.Markdown(case_note_upload)
202
+
203
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
204
+ clear = gr.Button("Restart")
205
+
206
+ num_beams = gr.Slider(
207
+ minimum=1,
208
+ maximum=10,
209
+ value=1,
210
+ step=1,
211
+ interactive=True,
212
+ label="beam search numbers)",
213
+ )
214
+
215
+ temperature = gr.Slider(
216
+ minimum=0.1,
217
+ maximum=2.0,
218
+ value=1.0,
219
+ step=0.1,
220
+ interactive=True,
221
+ label="Temperature",
222
+ )
223
+
224
+ audio = gr.Checkbox(interactive=True, value=False, label="Audio")
225
+ gr.Markdown(Note_markdown)
226
+ with gr.Column():
227
+ chat_state = gr.State()
228
+ img_list = gr.State()
229
+ chatbot = gr.Chatbot(label='Global-Local QFormer')
230
+ text_input = gr.Textbox(label='User', placeholder='Please upload your video first.', interactive=False)
231
+
232
+
233
+ '''with gr.Column():
234
+ gr.Examples(examples=[
235
+ [f"examples/skateboarding_dog.mp4", "What is the dog doing? "],
236
+ [f"examples/birthday.mp4", "What is the boy doing? "],
237
+ [f"examples/IronMan.mp4", "Is the guy in the video Iron Man? "],
238
+ ], inputs=[video, text_input])'''
239
+
240
+ gr.Markdown(cite_markdown)
241
+ upload_button.click(upload_imgorvideo, [video, text_input, chat_state,chatbot], [video, text_input, upload_button, chat_state, img_list,chatbot])
242
+
243
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
244
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
245
+ )
246
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, text_input, upload_button, chat_state, img_list], queue=False)
247
+
248
+ #demo.launch(share=False, enable_queue=True, debug=True)
249
+ demo.launch(share=False, debug=True)
ckpt/VL_LLaMA_2_7B_Finetuned.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3cec0e2979ed7656e08ecc5b185c2229a3c577b4b7a4721a94bd461ba0447c6e
3
+ size 265559201
ckpt/finetuned_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3795447e3459f467aae141873ba5f666efcc0f1478ddd9316437d3ba56aa72fd
3
+ size 38852011
demo_video.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/demo.py
3
+ """
4
+ import argparse
5
+ import os
6
+ import sys
7
+ import random
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.backends.cudnn as cudnn
12
+ import gradio as gr
13
+
14
+ from global_local.common.config import Config
15
+ from global_local.common.dist_utils import get_rank
16
+ from global_local.common.registry import registry
17
+ from global_local.conversation.conversation_video import Chat, Conversation, default_conversation,SeparatorStyle,conv_llava_llama_2
18
+ import decord
19
+ decord.bridge.set_bridge('torch')
20
+
21
+ #%%
22
+ # imports modules for registration
23
+ from global_local.datasets.builders import *
24
+ from global_local.models import *
25
+ from global_local.processors import *
26
+ from global_local.runners import *
27
+ from global_local.tasks import *
28
+
29
+ #%%
30
+ def parse_args():
31
+ parser = argparse.ArgumentParser(description="Demo")
32
+ #parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
33
+ parser.add_argument("--cfg-path", type=str, default='./eval_configs/conversation_demo.yaml', help="path to configuration file.")
34
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
35
+ parser.add_argument("--model_type", type=str, default='llama_v2', help="specify LLM")
36
+ parser.add_argument('--pretrained_weight_path', type=str, default="./ckpt/finetuned_model.pth", metavar='PWP',
37
+ help='path to pretrained weight path')
38
+ parser.add_argument('--num_frames_per_clip', type=int, default=16, metavar='NPPC',
39
+ help='specify how frames to use per clip')
40
+ parser.add_argument('--num_segments', type=int, default=4, metavar='NS',
41
+ help='specify number of video segments')
42
+ parser.add_argument('--hierarchical_agg_function', type=str, default="without-top-final-global-prompts-region-segment-full-dis-spatiotemporal-prompts-attn-early-attn-linear-learned", metavar='HAF',
43
+ help='specify function to merge global and clip visual representations')
44
+
45
+ parser.add_argument(
46
+ "--options",
47
+ nargs="+",
48
+ help="override some settings in the used config, the key-value pair "
49
+ "in xxx=yyy format will be merged into config file (deprecate), "
50
+ "change to --cfg-options instead.",
51
+ )
52
+ args = parser.parse_args()
53
+ return args
54
+
55
+
56
+ def setup_seeds(config):
57
+ seed = config.run_cfg.seed + get_rank()
58
+
59
+ random.seed(seed)
60
+ np.random.seed(seed)
61
+ torch.manual_seed(seed)
62
+
63
+ cudnn.benchmark = False
64
+ cudnn.deterministic = True
65
+
66
+
67
+ # ========================================
68
+ # Model Initialization
69
+ # ========================================
70
+
71
+ print('Initializing Chat')
72
+ args = parse_args()
73
+ cfg = Config(args)
74
+
75
+ model_config = cfg.model_cfg
76
+ model_config.device_8bit = args.gpu_id
77
+ model_cls = registry.get_model_class(model_config.arch)
78
+ model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
79
+
80
+ model.num_frames_per_clip = args.num_frames_per_clip
81
+ model.num_segments = args.num_segments
82
+ model.hierarchical_agg_function = args.hierarchical_agg_function
83
+ model.global_region_embed_weight = None
84
+
85
+ model.initialize_visual_agg_function()
86
+
87
+ best_checkpoint = torch.load(args.pretrained_weight_path, map_location='cpu')['model_state_dict']
88
+ pretrained_dict = {}
89
+ for k, v in best_checkpoint.items():
90
+ pretrained_dict[k.replace('module.', '')] = v
91
+
92
+ model_dict = model.state_dict()
93
+ model_dict.update(pretrained_dict)
94
+ model.load_state_dict(model_dict)
95
+ model.cuda().eval()
96
+
97
+ #vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
98
+ vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
99
+ vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
100
+ chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
101
+ print('Initialization Finished')
102
+
103
+ # ========================================
104
+ # Gradio Setting
105
+ # ========================================
106
+
107
+ def gradio_reset(chat_state, img_list):
108
+ if chat_state is not None:
109
+ chat_state.messages = []
110
+ if img_list is not None:
111
+ img_list = []
112
+ return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_list
113
+
114
+ def upload_imgorvideo(gr_video, gr_img, text_input, chat_state,chatbot):
115
+ if args.model_type == 'vicuna':
116
+ chat_state = default_conversation.copy()
117
+ else:
118
+ chat_state = conv_llava_llama_2.copy()
119
+ if gr_img is None and gr_video is None:
120
+ return None, None, None, gr.update(interactive=True), chat_state, None
121
+ elif gr_img is not None and gr_video is None:
122
+ print(gr_img)
123
+ chatbot = chatbot + [((gr_img,), None)]
124
+ chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
125
+ img_list = []
126
+ llm_message = chat.upload_img(gr_img, chat_state, img_list)
127
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list,chatbot
128
+ elif gr_video is not None and gr_img is None:
129
+ print(gr_video)
130
+ chatbot = chatbot + [((gr_video,), None)]
131
+ chat_state.system = "You are able to understand the visual content that the user provides. Follow the instructions carefully and explain your answers in detail."
132
+ img_list = []
133
+ llm_message = chat.upload_video_without_audio(gr_video, chat_state, img_list)
134
+ return gr.update(interactive=False), gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list,chatbot
135
+ else:
136
+ # img_list = []
137
+ return gr.update(interactive=False), gr.update(interactive=False, placeholder='Currently, only one input is supported'), gr.update(value="Currently, only one input is supported", interactive=False), chat_state, None,chatbot
138
+
139
+ def gradio_ask(user_message, chatbot, chat_state):
140
+ if len(user_message) == 0:
141
+ return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
142
+ chat.ask(user_message, chat_state)
143
+ chatbot = chatbot + [[user_message, None]]
144
+ return '', chatbot, chat_state
145
+
146
+
147
+ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
148
+ llm_message = chat.answer(conv=chat_state,
149
+ img_list=img_list,
150
+ num_beams=num_beams,
151
+ temperature=temperature,
152
+ max_new_tokens=300,
153
+ max_length=2000)[0]
154
+ chatbot[-1][1] = llm_message
155
+ print(chat_state.get_prompt())
156
+ print(chat_state)
157
+ return chatbot, chat_state, img_list
158
+
159
+ title = """
160
+ <h1 align="center">Global-Local QFormer for Long Video Understanding with LLMs</h1>
161
+
162
+ <h5 align="center"> Introduction: We introduce a Global-Local QFormer video model that is connected with a Large Language Model to understand \
163
+ and answer questions about long videos. </h5>
164
+
165
+ <div style='display:flex; gap: 0.25rem; '>
166
+ <a href='https://huggingface.co/spaces/rxtan/rxtan/Global-Local-QFormer-Video-LLM'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
167
+ <a href=''><img src='https://img.shields.io/badge/Paper-PDF-red'></a>
168
+ </div>
169
+
170
+
171
+ Thank you for using the Global-Local QFormer Demo Page! If you have any questions or feedback, feel free to contact us.
172
+ Current online demo uses the 7B version of Llama-2 due to resource limitations.
173
+
174
+
175
+ """
176
+
177
+ Note_markdown = ("""
178
+ ### We note that our Global-Local QFormer model may be limited at understanding videos from rare domains. Due to the pretraining data, the \
179
+ model may be susceptible to hallucinations
180
+ We would like to acknowledge the Video-LLama repository which we copied the demo layout from.
181
+
182
+ **Boston University**
183
+ """)
184
+
185
+ cite_markdown = ("""
186
+ """)
187
+
188
+ #case_note_upload = ("""
189
+ ### We provide some examples at the bottom of the page. Simply click on them to try them out directly.
190
+ #""")
191
+
192
+ #TODO show examples below
193
+
194
+ with gr.Blocks() as demo:
195
+ gr.Markdown(title)
196
+
197
+ with gr.Row():
198
+ with gr.Column(scale=0.5):
199
+ video = gr.Video()
200
+ #image = gr.Image(type="filepath")
201
+ #gr.Markdown(case_note_upload)
202
+
203
+ upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
204
+ clear = gr.Button("Restart")
205
+
206
+ num_beams = gr.Slider(
207
+ minimum=1,
208
+ maximum=10,
209
+ value=1,
210
+ step=1,
211
+ interactive=True,
212
+ label="beam search numbers)",
213
+ )
214
+
215
+ temperature = gr.Slider(
216
+ minimum=0.1,
217
+ maximum=2.0,
218
+ value=1.0,
219
+ step=0.1,
220
+ interactive=True,
221
+ label="Temperature",
222
+ )
223
+
224
+ audio = gr.Checkbox(interactive=True, value=False, label="Audio")
225
+ gr.Markdown(Note_markdown)
226
+ with gr.Column():
227
+ chat_state = gr.State()
228
+ img_list = gr.State()
229
+ chatbot = gr.Chatbot(label='Global-Local QFormer')
230
+ text_input = gr.Textbox(label='User', placeholder='Please upload your video first.', interactive=False)
231
+
232
+
233
+ '''with gr.Column():
234
+ gr.Examples(examples=[
235
+ [f"examples/skateboarding_dog.mp4", "What is the dog doing? "],
236
+ [f"examples/birthday.mp4", "What is the boy doing? "],
237
+ [f"examples/IronMan.mp4", "Is the guy in the video Iron Man? "],
238
+ ], inputs=[video, text_input])'''
239
+
240
+ gr.Markdown(cite_markdown)
241
+ upload_button.click(upload_imgorvideo, [video, text_input, chat_state,chatbot], [video, text_input, upload_button, chat_state, img_list,chatbot])
242
+
243
+ text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
244
+ gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
245
+ )
246
+ clear.click(gradio_reset, [chat_state, img_list], [chatbot, video, text_input, upload_button, chat_state, img_list], queue=False)
247
+
248
+ #demo.launch(share=False, enable_queue=True, debug=True)
249
+ demo.launch(share=False, debug=True)
eval_configs/conversation_demo.yaml ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: video_instruction_llama
3
+ model_type: pretrain_vicuna
4
+ freeze_vit: True
5
+ freeze_qformer: True
6
+
7
+
8
+ # Q-Former
9
+ num_query_token: 32
10
+
11
+ # If you want train models based on LLaMA-2-chat,
12
+ # some ckpts could be download from our provided huggingface repo
13
+ # i.e. https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Finetuned llama-2-7b-chat-hf
14
+ #llama_model: "/projectnb/ivc-ml/rxtan/llama-2-7b-chat-hf/"
15
+ llama_model: "Video-LLaMA-2-7B-Finetuned/llama-2-7b-chat-hf/"
16
+ imagebind_ckpt_path: "ckpt/imagebind_path/"
17
+
18
+ # The ckpt of vision branch after stage1 pretrained,
19
+ ckpt: 'ckpt/VL_LLaMA_2_7B_Finetuned.pth' # you can use our pretrained ckpt from https://huggingface.co/DAMO-NLP-SG/Video-LLaMA-2-13B-Pretrained/
20
+
21
+
22
+ # only train vision branch
23
+ equip_audio_branch: False # whether equips the audio branch
24
+ frozen_llama_proj: False
25
+ frozen_video_Qformer: True
26
+ frozen_audio_Qformer: True
27
+
28
+ fusion_head_layers: 2
29
+ max_frame_pos: 32
30
+ fusion_header_type: "seqTransf"
31
+
32
+ max_txt_len: 320
33
+
34
+ # for llama_2_chat:
35
+ end_sym: "</s>"
36
+ prompt_path: "prompts/alignment_image.txt"
37
+ prompt_template: '[INST] <<SYS>>\n \n<</SYS>>\n\n{} [/INST] '
38
+
39
+ datasets:
40
+ webvid:
41
+ vis_processor:
42
+ train:
43
+ name: "alpro_video_eval"
44
+ n_frms: 8
45
+ image_size: 224
46
+ text_processor:
47
+ train:
48
+ name: "blip_caption"
49
+
50
+ run:
51
+ task: video_text_pretrain
52
+ # optimizer
53
+ lr_sched: "linear_warmup_cosine_lr"
54
+ init_lr: 3e-5
55
+ min_lr: 1e-5
56
+ warmup_lr: 1e-6
57
+
58
+ weight_decay: 0.05
59
+ max_epoch: 3
60
+ iters_per_epoch: 1000
61
+ batch_size_train: 4
62
+ batch_size_eval: 4
63
+ num_workers: 4
64
+ warmup_steps: 1000
65
+
66
+ seed: 42
67
+ output_dir: "output/videollama_stage2_finetune"
68
+
69
+ amp: True
70
+ resume_ckpt_path: null
71
+
72
+ evaluate: False
73
+ train_splits: ["train"]
74
+
75
+ device: "cuda"
76
+ world_size: 1
77
+ dist_url: "env://"
78
+ distributed: True
global_local/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import os
9
+ import sys
10
+
11
+ from omegaconf import OmegaConf
12
+
13
+ from global_local.common.registry import registry
14
+
15
+ from global_local.datasets.builders import *
16
+ from global_local.models import *
17
+ from global_local.processors import *
18
+ from global_local.tasks import *
19
+
20
+
21
+ root_dir = os.path.dirname(os.path.abspath(__file__))
22
+ default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))
23
+
24
+ registry.register_path("library_root", root_dir)
25
+ repo_root = os.path.join(root_dir, "..")
26
+ registry.register_path("repo_root", repo_root)
27
+ cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
28
+ registry.register_path("cache_root", cache_root)
29
+
30
+ registry.register("MAX_INT", sys.maxsize)
31
+ registry.register("SPLIT_NAMES", ["train", "val", "test"])
global_local/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.1 kB). View file
 
global_local/common/__init__.py ADDED
File without changes
global_local/common/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (219 Bytes). View file
 
global_local/common/__pycache__/config.cpython-39.pyc ADDED
Binary file (12.2 kB). View file
 
global_local/common/__pycache__/dist_utils.cpython-39.pyc ADDED
Binary file (4.55 kB). View file
 
global_local/common/__pycache__/logger.cpython-39.pyc ADDED
Binary file (6.48 kB). View file
 
global_local/common/__pycache__/optims.cpython-39.pyc ADDED
Binary file (3.89 kB). View file
 
global_local/common/__pycache__/registry.cpython-39.pyc ADDED
Binary file (9.11 kB). View file
 
global_local/common/__pycache__/utils.cpython-39.pyc ADDED
Binary file (12.7 kB). View file
 
global_local/common/config.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from global_local.common.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ model_config_path = model_cls.default_config_path(model_type=model_type)
72
+
73
+ model_config = OmegaConf.create()
74
+ # hierarchy override, customized config > default config
75
+ model_config = OmegaConf.merge(
76
+ model_config,
77
+ OmegaConf.load(model_config_path),
78
+ {"model": config["model"]},
79
+ )
80
+
81
+ return model_config
82
+
83
+ @staticmethod
84
+ def build_runner_config(config):
85
+ return {"run": config.run}
86
+
87
+ @staticmethod
88
+ def build_dataset_config(config):
89
+ datasets = config.get("datasets", None)
90
+ if datasets is None:
91
+ raise KeyError(
92
+ "Expecting 'datasets' as the root key for dataset configuration."
93
+ )
94
+
95
+ dataset_config = OmegaConf.create()
96
+
97
+ for dataset_name in datasets:
98
+ builder_cls = registry.get_builder_class(dataset_name)
99
+
100
+ dataset_config_type = datasets[dataset_name].get("type", "default")
101
+ dataset_config_path = builder_cls.default_config_path(
102
+ type=dataset_config_type
103
+ )
104
+
105
+ # hierarchy override, customized config > default config
106
+ dataset_config = OmegaConf.merge(
107
+ dataset_config,
108
+ OmegaConf.load(dataset_config_path),
109
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110
+ )
111
+
112
+ return dataset_config
113
+
114
+ def _convert_to_dot_list(self, opts):
115
+ if opts is None:
116
+ opts = []
117
+
118
+ if len(opts) == 0:
119
+ return opts
120
+
121
+ has_equal = opts[0].find("=") != -1
122
+
123
+ if has_equal:
124
+ return opts
125
+
126
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127
+
128
+ def get_config(self):
129
+ return self.config
130
+
131
+ @property
132
+ def run_cfg(self):
133
+ return self.config.run
134
+
135
+ @property
136
+ def datasets_cfg(self):
137
+ return self.config.datasets
138
+
139
+ @property
140
+ def model_cfg(self):
141
+ return self.config.model
142
+
143
+ def pretty_print(self):
144
+ logging.info("\n===== Running Parameters =====")
145
+ logging.info(self._convert_node_to_json(self.config.run))
146
+
147
+ logging.info("\n====== Dataset Attributes ======")
148
+ datasets = self.config.datasets
149
+
150
+ for dataset in datasets:
151
+ if dataset in self.config.datasets:
152
+ logging.info(f"\n======== {dataset} =======")
153
+ dataset_config = self.config.datasets[dataset]
154
+ logging.info(self._convert_node_to_json(dataset_config))
155
+ else:
156
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157
+
158
+ logging.info(f"\n====== Model Attributes ======")
159
+ logging.info(self._convert_node_to_json(self.config.model))
160
+
161
+ def _convert_node_to_json(self, node):
162
+ container = OmegaConf.to_container(node, resolve=True)
163
+ return json.dumps(container, indent=4, sort_keys=True)
164
+
165
+ def to_dict(self):
166
+ return OmegaConf.to_container(self.config)
167
+
168
+
169
+ def node_to_dict(node):
170
+ return OmegaConf.to_container(node)
171
+
172
+
173
+ class ConfigValidator:
174
+ """
175
+ This is a preliminary implementation to centralize and validate the configuration.
176
+ May be altered in the future.
177
+
178
+ A helper class to validate configurations from yaml file.
179
+
180
+ This serves the following purposes:
181
+ 1. Ensure all the options in the yaml are defined, raise error if not.
182
+ 2. when type mismatches are found, the validator will raise an error.
183
+ 3. a central place to store and display helpful messages for supported configurations.
184
+
185
+ """
186
+
187
+ class _Argument:
188
+ def __init__(self, name, choices=None, type=None, help=None):
189
+ self.name = name
190
+ self.val = None
191
+ self.choices = choices
192
+ self.type = type
193
+ self.help = help
194
+
195
+ def __str__(self):
196
+ s = f"{self.name}={self.val}"
197
+ if self.type is not None:
198
+ s += f", ({self.type})"
199
+ if self.choices is not None:
200
+ s += f", choices: {self.choices}"
201
+ if self.help is not None:
202
+ s += f", ({self.help})"
203
+ return s
204
+
205
+ def __init__(self, description):
206
+ self.description = description
207
+
208
+ self.arguments = dict()
209
+
210
+ self.parsed_args = None
211
+
212
+ def __getitem__(self, key):
213
+ assert self.parsed_args is not None, "No arguments parsed yet."
214
+
215
+ return self.parsed_args[key]
216
+
217
+ def __str__(self) -> str:
218
+ return self.format_help()
219
+
220
+ def add_argument(self, *args, **kwargs):
221
+ """
222
+ Assume the first argument is the name of the argument.
223
+ """
224
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
225
+
226
+ def validate(self, config=None):
227
+ """
228
+ Convert yaml config (dict-like) to list, required by argparse.
229
+ """
230
+ for k, v in config.items():
231
+ assert (
232
+ k in self.arguments
233
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
234
+
235
+ if self.arguments[k].type is not None:
236
+ try:
237
+ self.arguments[k].val = self.arguments[k].type(v)
238
+ except ValueError:
239
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240
+
241
+ if self.arguments[k].choices is not None:
242
+ assert (
243
+ v in self.arguments[k].choices
244
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
245
+
246
+ return config
247
+
248
+ def format_arguments(self):
249
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
250
+
251
+ def format_help(self):
252
+ # description + key-value pair string for each argument
253
+ help_msg = str(self.description)
254
+ return help_msg + ", available arguments: " + self.format_arguments()
255
+
256
+ def print_help(self):
257
+ # display help message
258
+ print(self.format_help())
259
+
260
+
261
+ def create_runner_config_validator():
262
+ validator = ConfigValidator(description="Runner configurations")
263
+
264
+ validator.add_argument(
265
+ "runner",
266
+ type=str,
267
+ choices=["runner_base", "runner_iter"],
268
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
269
+ runner runs based on iters. Default: runner_base""",
270
+ )
271
+ # add argumetns for training dataset ratios
272
+ validator.add_argument(
273
+ "train_dataset_ratios",
274
+ type=Dict[str, float],
275
+ help="""Ratios of training dataset. This is used in iteration-based runner.
276
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
277
+ Default: None""",
278
+ )
279
+ validator.add_argument(
280
+ "max_iters",
281
+ type=float,
282
+ help="Maximum number of iterations to run.",
283
+ )
284
+ validator.add_argument(
285
+ "max_epoch",
286
+ type=int,
287
+ help="Maximum number of epochs to run.",
288
+ )
289
+ # add arguments for iters_per_inner_epoch
290
+ validator.add_argument(
291
+ "iters_per_inner_epoch",
292
+ type=float,
293
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294
+ )
295
+ lr_scheds_choices = registry.list_lr_schedulers()
296
+ validator.add_argument(
297
+ "lr_sched",
298
+ type=str,
299
+ choices=lr_scheds_choices,
300
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301
+ )
302
+ task_choices = registry.list_tasks()
303
+ validator.add_argument(
304
+ "task",
305
+ type=str,
306
+ choices=task_choices,
307
+ help="Task to use, from {}".format(task_choices),
308
+ )
309
+ # add arguments for init_lr
310
+ validator.add_argument(
311
+ "init_lr",
312
+ type=float,
313
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314
+ )
315
+ # add arguments for min_lr
316
+ validator.add_argument(
317
+ "min_lr",
318
+ type=float,
319
+ help="Minimum learning rate (after decay).",
320
+ )
321
+ # add arguments for warmup_lr
322
+ validator.add_argument(
323
+ "warmup_lr",
324
+ type=float,
325
+ help="Starting learning rate for warmup.",
326
+ )
327
+ # add arguments for learning rate decay rate
328
+ validator.add_argument(
329
+ "lr_decay_rate",
330
+ type=float,
331
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332
+ )
333
+ # add arguments for weight decay
334
+ validator.add_argument(
335
+ "weight_decay",
336
+ type=float,
337
+ help="Weight decay rate.",
338
+ )
339
+ # add arguments for training batch size
340
+ validator.add_argument(
341
+ "batch_size_train",
342
+ type=int,
343
+ help="Training batch size.",
344
+ )
345
+ # add arguments for evaluation batch size
346
+ validator.add_argument(
347
+ "batch_size_eval",
348
+ type=int,
349
+ help="Evaluation batch size, including validation and testing.",
350
+ )
351
+ # add arguments for number of workers for data loading
352
+ validator.add_argument(
353
+ "num_workers",
354
+ help="Number of workers for data loading.",
355
+ )
356
+ # add arguments for warm up steps
357
+ validator.add_argument(
358
+ "warmup_steps",
359
+ type=int,
360
+ help="Number of warmup steps. Required if a warmup schedule is used.",
361
+ )
362
+ # add arguments for random seed
363
+ validator.add_argument(
364
+ "seed",
365
+ type=int,
366
+ help="Random seed.",
367
+ )
368
+ # add arguments for output directory
369
+ validator.add_argument(
370
+ "output_dir",
371
+ type=str,
372
+ help="Output directory to save checkpoints and logs.",
373
+ )
374
+ # add arguments for whether only use evaluation
375
+ validator.add_argument(
376
+ "evaluate",
377
+ help="Whether to only evaluate the model. If true, training will not be performed.",
378
+ )
379
+ # add arguments for splits used for training, e.g. ["train", "val"]
380
+ validator.add_argument(
381
+ "train_splits",
382
+ type=list,
383
+ help="Splits to use for training.",
384
+ )
385
+ # add arguments for splits used for validation, e.g. ["val"]
386
+ validator.add_argument(
387
+ "valid_splits",
388
+ type=list,
389
+ help="Splits to use for validation. If not provided, will skip the validation.",
390
+ )
391
+ # add arguments for splits used for testing, e.g. ["test"]
392
+ validator.add_argument(
393
+ "test_splits",
394
+ type=list,
395
+ help="Splits to use for testing. If not provided, will skip the testing.",
396
+ )
397
+ # add arguments for accumulating gradient for iterations
398
+ validator.add_argument(
399
+ "accum_grad_iters",
400
+ type=int,
401
+ help="Number of iterations to accumulate gradient for.",
402
+ )
403
+
404
+ # ====== distributed training ======
405
+ validator.add_argument(
406
+ "device",
407
+ type=str,
408
+ choices=["cpu", "cuda"],
409
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
410
+ )
411
+ validator.add_argument(
412
+ "world_size",
413
+ type=int,
414
+ help="Number of processes participating in the job.",
415
+ )
416
+ validator.add_argument("dist_url", type=str)
417
+ validator.add_argument("distributed", type=bool)
418
+ # add arguments to opt using distributed sampler during evaluation or not
419
+ validator.add_argument(
420
+ "use_dist_eval_sampler",
421
+ type=bool,
422
+ help="Whether to use distributed sampler during evaluation or not.",
423
+ )
424
+
425
+ # ====== task specific ======
426
+ # generation task specific arguments
427
+ # add arguments for maximal length of text output
428
+ validator.add_argument(
429
+ "max_len",
430
+ type=int,
431
+ help="Maximal length of text output.",
432
+ )
433
+ # add arguments for minimal length of text output
434
+ validator.add_argument(
435
+ "min_len",
436
+ type=int,
437
+ help="Minimal length of text output.",
438
+ )
439
+ # add arguments number of beams
440
+ validator.add_argument(
441
+ "num_beams",
442
+ type=int,
443
+ help="Number of beams used for beam search.",
444
+ )
445
+
446
+ # vqa task specific arguments
447
+ # add arguments for number of answer candidates
448
+ validator.add_argument(
449
+ "num_ans_candidates",
450
+ type=int,
451
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452
+ )
453
+ # add arguments for inference method
454
+ validator.add_argument(
455
+ "inference_method",
456
+ type=str,
457
+ choices=["genearte", "rank"],
458
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
459
+ )
460
+
461
+ # ====== model specific ======
462
+ validator.add_argument(
463
+ "k_test",
464
+ type=int,
465
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466
+ )
467
+
468
+ return validator
global_local/common/dist_utils.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import functools
10
+ import os
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import timm.models.hub as timm_hub
15
+
16
+
17
+ def setup_for_distributed(is_master):
18
+ """
19
+ This function disables printing when not in master process
20
+ """
21
+ import builtins as __builtin__
22
+
23
+ builtin_print = __builtin__.print
24
+
25
+ def print(*args, **kwargs):
26
+ force = kwargs.pop("force", False)
27
+ if is_master or force:
28
+ builtin_print(*args, **kwargs)
29
+
30
+ __builtin__.print = print
31
+
32
+
33
+ def is_dist_avail_and_initialized():
34
+ if not dist.is_available():
35
+ return False
36
+ if not dist.is_initialized():
37
+ return False
38
+ return True
39
+
40
+
41
+ def get_world_size():
42
+ if not is_dist_avail_and_initialized():
43
+ return 1
44
+ return dist.get_world_size()
45
+
46
+
47
+ def get_rank():
48
+ if not is_dist_avail_and_initialized():
49
+ return 0
50
+ return dist.get_rank()
51
+
52
+
53
+ def is_main_process():
54
+ return get_rank() == 0
55
+
56
+
57
+ def init_distributed_mode(args):
58
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
59
+ args.rank = int(os.environ["RANK"])
60
+ args.world_size = int(os.environ["WORLD_SIZE"])
61
+ args.gpu = int(os.environ["LOCAL_RANK"])
62
+ elif "SLURM_PROCID" in os.environ:
63
+ args.rank = int(os.environ["SLURM_PROCID"])
64
+ args.gpu = args.rank % torch.cuda.device_count()
65
+ else:
66
+ print("Not using distributed mode")
67
+ args.distributed = False
68
+ return
69
+
70
+ args.distributed = True
71
+
72
+ torch.cuda.set_device(args.gpu)
73
+ args.dist_backend = "nccl"
74
+ print(
75
+ "| distributed init (rank {}, world {}): {}".format(
76
+ args.rank, args.world_size, args.dist_url
77
+ ),
78
+ flush=True,
79
+ )
80
+ torch.distributed.init_process_group(
81
+ backend=args.dist_backend,
82
+ init_method=args.dist_url,
83
+ world_size=args.world_size,
84
+ rank=args.rank,
85
+ timeout=datetime.timedelta(
86
+ days=365
87
+ ), # allow auto-downloading and de-compressing
88
+ )
89
+ torch.distributed.barrier()
90
+ setup_for_distributed(args.rank == 0)
91
+
92
+
93
+ def get_dist_info():
94
+ if torch.__version__ < "1.0":
95
+ initialized = dist._initialized
96
+ else:
97
+ initialized = dist.is_initialized()
98
+ if initialized:
99
+ rank = dist.get_rank()
100
+ world_size = dist.get_world_size()
101
+ else: # non-distributed training
102
+ rank = 0
103
+ world_size = 1
104
+ return rank, world_size
105
+
106
+ def no_grad_all_gather(tensors):
107
+ """
108
+ All gathers the provided tensors from all processes across machines.
109
+ Args:
110
+ tensors (list): tensors to perform all gather across all processes in
111
+ all machines.
112
+ """
113
+
114
+ gather_list = []
115
+ output_tensor = []
116
+ world_size = dist.get_world_size()
117
+
118
+ for tensor in tensors:
119
+ tensor_placeholder = [torch.ones_like(tensor) for _ in range(world_size)]
120
+ dist.all_gather(tensor_placeholder, tensor, async_op=False)
121
+ gather_list.append(tensor_placeholder)
122
+ for gathered_tensor in gather_list:
123
+ output_tensor.append(torch.cat(gathered_tensor, dim=0))
124
+ return output_tensor
125
+
126
+ def main_process(func):
127
+ @functools.wraps(func)
128
+ def wrapper(*args, **kwargs):
129
+ rank, _ = get_dist_info()
130
+ if rank == 0:
131
+ return func(*args, **kwargs)
132
+
133
+ return wrapper
134
+
135
+
136
+ def download_cached_file(url, check_hash=True, progress=False):
137
+ """
138
+ Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
139
+ If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
140
+ """
141
+
142
+ def get_cached_file_path():
143
+ # a hack to sync the file path across processes
144
+ parts = torch.hub.urlparse(url)
145
+ filename = os.path.basename(parts.path)
146
+ cached_file = os.path.join(timm_hub.get_cache_dir(), filename)
147
+
148
+ return cached_file
149
+
150
+ if is_main_process():
151
+ timm_hub.download_cached_file(url, check_hash, progress)
152
+
153
+ #if is_dist_avail_and_initialized():
154
+ # dist.barrier()
155
+
156
+ return get_cached_file_path()
global_local/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from matplotlib import pyplot as plt
3
+ from scipy.ndimage import filters
4
+ from skimage import transform as skimage_transform
5
+
6
+
7
+ def getAttMap(img, attMap, blur=True, overlap=True):
8
+ attMap -= attMap.min()
9
+ if attMap.max() > 0:
10
+ attMap /= attMap.max()
11
+ attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
12
+ if blur:
13
+ attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
14
+ attMap -= attMap.min()
15
+ attMap /= attMap.max()
16
+ cmap = plt.get_cmap("jet")
17
+ attMapV = cmap(attMap)
18
+ attMapV = np.delete(attMapV, 3, 2)
19
+ if overlap:
20
+ attMap = (
21
+ 1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
22
+ + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
23
+ )
24
+ return attMap
global_local/common/logger.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import datetime
9
+ import logging
10
+ import time
11
+ from collections import defaultdict, deque
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+
16
+ from global_local.common import dist_utils
17
+
18
+
19
+ class SmoothedValue(object):
20
+ """Track a series of values and provide access to smoothed values over a
21
+ window or the global series average.
22
+ """
23
+
24
+ def __init__(self, window_size=20, fmt=None):
25
+ if fmt is None:
26
+ fmt = "{median:.4f} ({global_avg:.4f})"
27
+ self.deque = deque(maxlen=window_size)
28
+ self.total = 0.0
29
+ self.count = 0
30
+ self.fmt = fmt
31
+
32
+ def update(self, value, n=1):
33
+ self.deque.append(value)
34
+ self.count += n
35
+ self.total += value * n
36
+
37
+ def synchronize_between_processes(self):
38
+ """
39
+ Warning: does not synchronize the deque!
40
+ """
41
+ if not dist_utils.is_dist_avail_and_initialized():
42
+ return
43
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
44
+ dist.barrier()
45
+ dist.all_reduce(t)
46
+ t = t.tolist()
47
+ self.count = int(t[0])
48
+ self.total = t[1]
49
+
50
+ @property
51
+ def median(self):
52
+ d = torch.tensor(list(self.deque))
53
+ return d.median().item()
54
+
55
+ @property
56
+ def avg(self):
57
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
58
+ return d.mean().item()
59
+
60
+ @property
61
+ def global_avg(self):
62
+ return self.total / self.count
63
+
64
+ @property
65
+ def max(self):
66
+ return max(self.deque)
67
+
68
+ @property
69
+ def value(self):
70
+ return self.deque[-1]
71
+
72
+ def __str__(self):
73
+ return self.fmt.format(
74
+ median=self.median,
75
+ avg=self.avg,
76
+ global_avg=self.global_avg,
77
+ max=self.max,
78
+ value=self.value,
79
+ )
80
+
81
+
82
+ class MetricLogger(object):
83
+ def __init__(self, delimiter="\t"):
84
+ self.meters = defaultdict(SmoothedValue)
85
+ self.delimiter = delimiter
86
+
87
+ def update(self, **kwargs):
88
+ for k, v in kwargs.items():
89
+ if isinstance(v, torch.Tensor):
90
+ v = v.item()
91
+ assert isinstance(v, (float, int))
92
+ self.meters[k].update(v)
93
+
94
+ def __getattr__(self, attr):
95
+ if attr in self.meters:
96
+ return self.meters[attr]
97
+ if attr in self.__dict__:
98
+ return self.__dict__[attr]
99
+ raise AttributeError(
100
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
101
+ )
102
+
103
+ def __str__(self):
104
+ loss_str = []
105
+ for name, meter in self.meters.items():
106
+ loss_str.append("{}: {}".format(name, str(meter)))
107
+ return self.delimiter.join(loss_str)
108
+
109
+ def global_avg(self):
110
+ loss_str = []
111
+ for name, meter in self.meters.items():
112
+ loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
113
+ return self.delimiter.join(loss_str)
114
+
115
+ def synchronize_between_processes(self):
116
+ for meter in self.meters.values():
117
+ meter.synchronize_between_processes()
118
+
119
+ def add_meter(self, name, meter):
120
+ self.meters[name] = meter
121
+
122
+ def log_every(self, iterable, print_freq, header=None):
123
+ i = 0
124
+ if not header:
125
+ header = ""
126
+ start_time = time.time()
127
+ end = time.time()
128
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
129
+ data_time = SmoothedValue(fmt="{avg:.4f}")
130
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
131
+ log_msg = [
132
+ header,
133
+ "[{0" + space_fmt + "}/{1}]",
134
+ "eta: {eta}",
135
+ "{meters}",
136
+ "time: {time}",
137
+ "data: {data}",
138
+ ]
139
+ if torch.cuda.is_available():
140
+ log_msg.append("max mem: {memory:.0f}")
141
+ log_msg = self.delimiter.join(log_msg)
142
+ MB = 1024.0 * 1024.0
143
+ for obj in iterable:
144
+ data_time.update(time.time() - end)
145
+ yield obj
146
+ iter_time.update(time.time() - end)
147
+ if i % print_freq == 0 or i == len(iterable) - 1:
148
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
149
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
150
+ if torch.cuda.is_available():
151
+ print(
152
+ log_msg.format(
153
+ i,
154
+ len(iterable),
155
+ eta=eta_string,
156
+ meters=str(self),
157
+ time=str(iter_time),
158
+ data=str(data_time),
159
+ memory=torch.cuda.max_memory_allocated() / MB,
160
+ )
161
+ )
162
+ else:
163
+ print(
164
+ log_msg.format(
165
+ i,
166
+ len(iterable),
167
+ eta=eta_string,
168
+ meters=str(self),
169
+ time=str(iter_time),
170
+ data=str(data_time),
171
+ )
172
+ )
173
+ i += 1
174
+ end = time.time()
175
+ total_time = time.time() - start_time
176
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
177
+ print(
178
+ "{} Total time: {} ({:.4f} s / it)".format(
179
+ header, total_time_str, total_time / len(iterable)
180
+ )
181
+ )
182
+
183
+
184
+ class AttrDict(dict):
185
+ def __init__(self, *args, **kwargs):
186
+ super(AttrDict, self).__init__(*args, **kwargs)
187
+ self.__dict__ = self
188
+
189
+
190
+ def setup_logger():
191
+ logging.basicConfig(
192
+ level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
193
+ format="%(asctime)s [%(levelname)s] %(message)s",
194
+ handlers=[logging.StreamHandler()],
195
+ )
global_local/common/optims.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import math
9
+
10
+ from global_local.common.registry import registry
11
+ from torch.optim.lr_scheduler import LambdaLR
12
+
13
+
14
+ @registry.register_lr_scheduler("linear_warmup_step_lr")
15
+ class LinearWarmupStepLRScheduler:
16
+ def __init__(
17
+ self,
18
+ optimizer,
19
+ max_epoch,
20
+ min_lr,
21
+ init_lr,
22
+ decay_rate=1,
23
+ warmup_start_lr=-1,
24
+ warmup_steps=0,
25
+ **kwargs
26
+ ):
27
+ self.optimizer = optimizer
28
+
29
+ self.max_epoch = max_epoch
30
+ self.min_lr = min_lr
31
+
32
+ self.decay_rate = decay_rate
33
+
34
+ self.init_lr = init_lr
35
+ self.warmup_steps = warmup_steps
36
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
37
+
38
+ def step(self, cur_epoch, cur_step):
39
+ if cur_epoch == 0:
40
+ warmup_lr_schedule(
41
+ step=cur_step,
42
+ optimizer=self.optimizer,
43
+ max_step=self.warmup_steps,
44
+ init_lr=self.warmup_start_lr,
45
+ max_lr=self.init_lr,
46
+ )
47
+ else:
48
+ step_lr_schedule(
49
+ epoch=cur_epoch,
50
+ optimizer=self.optimizer,
51
+ init_lr=self.init_lr,
52
+ min_lr=self.min_lr,
53
+ decay_rate=self.decay_rate,
54
+ )
55
+
56
+
57
+ @registry.register_lr_scheduler("linear_warmup_cosine_lr")
58
+ class LinearWarmupCosineLRScheduler:
59
+ def __init__(
60
+ self,
61
+ optimizer,
62
+ max_epoch,
63
+ iters_per_epoch,
64
+ min_lr,
65
+ init_lr,
66
+ warmup_steps=0,
67
+ warmup_start_lr=-1,
68
+ **kwargs
69
+ ):
70
+ self.optimizer = optimizer
71
+
72
+ self.max_epoch = max_epoch
73
+ self.iters_per_epoch = iters_per_epoch
74
+ self.min_lr = min_lr
75
+
76
+ self.init_lr = init_lr
77
+ self.warmup_steps = warmup_steps
78
+ self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr
79
+
80
+ def step(self, cur_epoch, cur_step):
81
+ total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
82
+ if total_cur_step < self.warmup_steps:
83
+ warmup_lr_schedule(
84
+ step=cur_step,
85
+ optimizer=self.optimizer,
86
+ max_step=self.warmup_steps,
87
+ init_lr=self.warmup_start_lr,
88
+ max_lr=self.init_lr,
89
+ )
90
+ else:
91
+ cosine_lr_schedule(
92
+ epoch=total_cur_step,
93
+ optimizer=self.optimizer,
94
+ max_epoch=self.max_epoch * self.iters_per_epoch,
95
+ init_lr=self.init_lr,
96
+ min_lr=self.min_lr,
97
+ )
98
+
99
+
100
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
101
+ """Decay the learning rate"""
102
+ lr = (init_lr - min_lr) * 0.5 * (
103
+ 1.0 + math.cos(math.pi * epoch / max_epoch)
104
+ ) + min_lr
105
+ for param_group in optimizer.param_groups:
106
+ param_group["lr"] = lr
107
+
108
+
109
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
110
+ """Warmup the learning rate"""
111
+ lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
112
+ for param_group in optimizer.param_groups:
113
+ param_group["lr"] = lr
114
+
115
+
116
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
117
+ """Decay the learning rate"""
118
+ lr = max(min_lr, init_lr * (decay_rate**epoch))
119
+ for param_group in optimizer.param_groups:
120
+ param_group["lr"] = lr
121
+
122
+ def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
123
+ """ Create a schedule with a learning rate that decreases following the
124
+ values of the cosine function between 0 and `pi * cycles` after a warmup
125
+ period during which it increases linearly between 0 and 1.
126
+ """
127
+
128
+ def lr_lambda(current_step):
129
+ if current_step < num_warmup_steps:
130
+ return float(current_step) / float(max(1, num_warmup_steps))
131
+ progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
132
+ return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
133
+
134
+ return LambdaLR(optimizer, lr_lambda, last_epoch)
global_local/common/registry.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+
9
+ class Registry:
10
+ mapping = {
11
+ "builder_name_mapping": {},
12
+ "task_name_mapping": {},
13
+ "processor_name_mapping": {},
14
+ "model_name_mapping": {},
15
+ "lr_scheduler_name_mapping": {},
16
+ "runner_name_mapping": {},
17
+ "state": {},
18
+ "paths": {},
19
+ }
20
+
21
+ @classmethod
22
+ def register_builder(cls, name):
23
+ r"""Register a dataset builder to registry with key 'name'
24
+
25
+ Args:
26
+ name: Key with which the builder will be registered.
27
+
28
+ Usage:
29
+
30
+ from video_llama.common.registry import registry
31
+ from video_llama.datasets.base_dataset_builder import BaseDatasetBuilder
32
+ """
33
+
34
+ def wrap(builder_cls):
35
+ from global_local.datasets.builders.base_dataset_builder import BaseDatasetBuilder
36
+
37
+ assert issubclass(
38
+ builder_cls, BaseDatasetBuilder
39
+ ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
40
+ builder_cls
41
+ )
42
+ if name in cls.mapping["builder_name_mapping"]:
43
+ raise KeyError(
44
+ "Name '{}' already registered for {}.".format(
45
+ name, cls.mapping["builder_name_mapping"][name]
46
+ )
47
+ )
48
+ cls.mapping["builder_name_mapping"][name] = builder_cls
49
+ return builder_cls
50
+
51
+ return wrap
52
+
53
+ @classmethod
54
+ def register_task(cls, name):
55
+ r"""Register a task to registry with key 'name'
56
+
57
+ Args:
58
+ name: Key with which the task will be registered.
59
+
60
+ Usage:
61
+
62
+ from video_llama.common.registry import registry
63
+ """
64
+
65
+ def wrap(task_cls):
66
+ from global_local.tasks.base_task import BaseTask
67
+
68
+ assert issubclass(
69
+ task_cls, BaseTask
70
+ ), "All tasks must inherit BaseTask class"
71
+ if name in cls.mapping["task_name_mapping"]:
72
+ raise KeyError(
73
+ "Name '{}' already registered for {}.".format(
74
+ name, cls.mapping["task_name_mapping"][name]
75
+ )
76
+ )
77
+ cls.mapping["task_name_mapping"][name] = task_cls
78
+ return task_cls
79
+
80
+ return wrap
81
+
82
+ @classmethod
83
+ def register_model(cls, name):
84
+ r"""Register a task to registry with key 'name'
85
+
86
+ Args:
87
+ name: Key with which the task will be registered.
88
+
89
+ Usage:
90
+
91
+ from video_llama.common.registry import registry
92
+ """
93
+
94
+ def wrap(model_cls):
95
+ from global_local.models import BaseModel
96
+
97
+ assert issubclass(
98
+ model_cls, BaseModel
99
+ ), "All models must inherit BaseModel class"
100
+ if name in cls.mapping["model_name_mapping"]:
101
+ raise KeyError(
102
+ "Name '{}' already registered for {}.".format(
103
+ name, cls.mapping["model_name_mapping"][name]
104
+ )
105
+ )
106
+ cls.mapping["model_name_mapping"][name] = model_cls
107
+ return model_cls
108
+
109
+ return wrap
110
+
111
+ @classmethod
112
+ def register_processor(cls, name):
113
+ r"""Register a processor to registry with key 'name'
114
+
115
+ Args:
116
+ name: Key with which the task will be registered.
117
+
118
+ Usage:
119
+
120
+ from video_llama.common.registry import registry
121
+ """
122
+
123
+ def wrap(processor_cls):
124
+ from global_local.processors import BaseProcessor
125
+
126
+ assert issubclass(
127
+ processor_cls, BaseProcessor
128
+ ), "All processors must inherit BaseProcessor class"
129
+ if name in cls.mapping["processor_name_mapping"]:
130
+ raise KeyError(
131
+ "Name '{}' already registered for {}.".format(
132
+ name, cls.mapping["processor_name_mapping"][name]
133
+ )
134
+ )
135
+ cls.mapping["processor_name_mapping"][name] = processor_cls
136
+ return processor_cls
137
+
138
+ return wrap
139
+
140
+ @classmethod
141
+ def register_lr_scheduler(cls, name):
142
+ r"""Register a model to registry with key 'name'
143
+
144
+ Args:
145
+ name: Key with which the task will be registered.
146
+
147
+ Usage:
148
+
149
+ from video_llama.common.registry import registry
150
+ """
151
+
152
+ def wrap(lr_sched_cls):
153
+ if name in cls.mapping["lr_scheduler_name_mapping"]:
154
+ raise KeyError(
155
+ "Name '{}' already registered for {}.".format(
156
+ name, cls.mapping["lr_scheduler_name_mapping"][name]
157
+ )
158
+ )
159
+ cls.mapping["lr_scheduler_name_mapping"][name] = lr_sched_cls
160
+ return lr_sched_cls
161
+
162
+ return wrap
163
+
164
+ @classmethod
165
+ def register_runner(cls, name):
166
+ r"""Register a model to registry with key 'name'
167
+
168
+ Args:
169
+ name: Key with which the task will be registered.
170
+
171
+ Usage:
172
+
173
+ from video_llama.common.registry import registry
174
+ """
175
+
176
+ def wrap(runner_cls):
177
+ if name in cls.mapping["runner_name_mapping"]:
178
+ raise KeyError(
179
+ "Name '{}' already registered for {}.".format(
180
+ name, cls.mapping["runner_name_mapping"][name]
181
+ )
182
+ )
183
+ cls.mapping["runner_name_mapping"][name] = runner_cls
184
+ return runner_cls
185
+
186
+ return wrap
187
+
188
+ @classmethod
189
+ def register_path(cls, name, path):
190
+ r"""Register a path to registry with key 'name'
191
+
192
+ Args:
193
+ name: Key with which the path will be registered.
194
+
195
+ Usage:
196
+
197
+ from video_llama.common.registry import registry
198
+ """
199
+ assert isinstance(path, str), "All path must be str."
200
+ if name in cls.mapping["paths"]:
201
+ raise KeyError("Name '{}' already registered.".format(name))
202
+ cls.mapping["paths"][name] = path
203
+
204
+ @classmethod
205
+ def register(cls, name, obj):
206
+ r"""Register an item to registry with key 'name'
207
+
208
+ Args:
209
+ name: Key with which the item will be registered.
210
+
211
+ Usage::
212
+
213
+ from video_llama.common.registry import registry
214
+
215
+ registry.register("config", {})
216
+ """
217
+ path = name.split(".")
218
+ current = cls.mapping["state"]
219
+
220
+ for part in path[:-1]:
221
+ if part not in current:
222
+ current[part] = {}
223
+ current = current[part]
224
+
225
+ current[path[-1]] = obj
226
+
227
+ # @classmethod
228
+ # def get_trainer_class(cls, name):
229
+ # return cls.mapping["trainer_name_mapping"].get(name, None)
230
+
231
+ @classmethod
232
+ def get_builder_class(cls, name):
233
+ return cls.mapping["builder_name_mapping"].get(name, None)
234
+
235
+ @classmethod
236
+ def get_model_class(cls, name):
237
+ return cls.mapping["model_name_mapping"].get(name, None)
238
+
239
+ @classmethod
240
+ def get_task_class(cls, name):
241
+ return cls.mapping["task_name_mapping"].get(name, None)
242
+
243
+ @classmethod
244
+ def get_processor_class(cls, name):
245
+ return cls.mapping["processor_name_mapping"].get(name, None)
246
+
247
+ @classmethod
248
+ def get_lr_scheduler_class(cls, name):
249
+ return cls.mapping["lr_scheduler_name_mapping"].get(name, None)
250
+
251
+ @classmethod
252
+ def get_runner_class(cls, name):
253
+ return cls.mapping["runner_name_mapping"].get(name, None)
254
+
255
+ @classmethod
256
+ def list_runners(cls):
257
+ return sorted(cls.mapping["runner_name_mapping"].keys())
258
+
259
+ @classmethod
260
+ def list_models(cls):
261
+ return sorted(cls.mapping["model_name_mapping"].keys())
262
+
263
+ @classmethod
264
+ def list_tasks(cls):
265
+ return sorted(cls.mapping["task_name_mapping"].keys())
266
+
267
+ @classmethod
268
+ def list_processors(cls):
269
+ return sorted(cls.mapping["processor_name_mapping"].keys())
270
+
271
+ @classmethod
272
+ def list_lr_schedulers(cls):
273
+ return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())
274
+
275
+ @classmethod
276
+ def list_datasets(cls):
277
+ return sorted(cls.mapping["builder_name_mapping"].keys())
278
+
279
+ @classmethod
280
+ def get_path(cls, name):
281
+ return cls.mapping["paths"].get(name, None)
282
+
283
+ @classmethod
284
+ def get(cls, name, default=None, no_warning=False):
285
+ r"""Get an item from registry with key 'name'
286
+
287
+ Args:
288
+ name (string): Key whose value needs to be retrieved.
289
+ default: If passed and key is not in registry, default value will
290
+ be returned with a warning. Default: None
291
+ no_warning (bool): If passed as True, warning when key doesn't exist
292
+ will not be generated. Useful for MMF's
293
+ internal operations. Default: False
294
+ """
295
+ original_name = name
296
+ name = name.split(".")
297
+ value = cls.mapping["state"]
298
+ for subname in name:
299
+ value = value.get(subname, default)
300
+ if value is default:
301
+ break
302
+
303
+ if (
304
+ "writer" in cls.mapping["state"]
305
+ and value == default
306
+ and no_warning is False
307
+ ):
308
+ cls.mapping["state"]["writer"].warning(
309
+ "Key {} is not present in registry, returning default value "
310
+ "of {}".format(original_name, default)
311
+ )
312
+ return value
313
+
314
+ @classmethod
315
+ def unregister(cls, name):
316
+ r"""Remove an item from registry with key 'name'
317
+
318
+ Args:
319
+ name: Key which needs to be removed.
320
+ Usage::
321
+
322
+ from mmf.common.registry import registry
323
+
324
+ config = registry.unregister("config")
325
+ """
326
+ return cls.mapping["state"].pop(name, None)
327
+
328
+
329
+ registry = Registry()
global_local/common/utils.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from global_local.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a url. look for http(s):// and ignoring the case
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
global_local/configs/datasets/cc_sbu/align.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ cc_sbu_align:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/cc_sbu_align_dataset
global_local/configs/datasets/cc_sbu/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ cc_sbu:
3
+ data_type: images
4
+ build_info:
5
+ storage: /path/to/cc_sbu_dataset/{00000..00001}.tar
global_local/configs/datasets/instruct/llava_instruct.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ llava_instruct:
3
+ data_type: image
4
+ build_info:
5
+ anno_dir: /path/llava_instruct_150k.json
6
+ videos_dir: /path/train2014/train2014/
global_local/configs/datasets/instruct/webvid_instruct.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ webvid_instruct:
3
+ data_type: image
4
+ build_info:
5
+ anno_dir: /path/webvid_align/videochat_instruct_11k.json
6
+ videos_dir: /path/webvid_align/videos/
global_local/configs/datasets/laion/defaults.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ datasets:
2
+ laion:
3
+ data_type: images
4
+ build_info:
5
+ storage: path/laion/laion_dataset/{00000..00001}.tar
global_local/configs/datasets/webvid/defaults.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ datasets:
2
+ webvid:
3
+ data_type: video
4
+ build_info:
5
+ anno_dir: path/webvid/webvid_tain_data/annotations/
6
+ videos_dir: path//webvid/webvid_tain_data/videos/
global_local/configs/default.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ env:
2
+ # For default users
3
+ # cache_root: "cache"
4
+ # For internal use with persistent storage
5
+ cache_root: "/export/home/.cache/minigpt4"
global_local/configs/models/minigpt4.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: mini_gpt4
3
+
4
+ # vit encoder
5
+ image_size: 224
6
+ drop_path_rate: 0
7
+ use_grad_checkpoint: False
8
+ vit_precision: "fp16"
9
+ freeze_vit: True
10
+ freeze_qformer: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # Vicuna
16
+ llama_model: "ckpt/vicuna-13b/"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "blip2_image_train"
25
+ image_size: 224
26
+ eval:
27
+ name: "blip2_image_eval"
28
+ image_size: 224
29
+ text_processor:
30
+ train:
31
+ name: "blip_caption"
32
+ eval:
33
+ name: "blip_caption"
global_local/configs/models/video_llama.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ arch: video_llama
3
+
4
+ # vit encoder
5
+ image_size: 224
6
+ drop_path_rate: 0
7
+ use_grad_checkpoint: False
8
+ vit_precision: "fp16"
9
+ freeze_vit: True
10
+ freeze_qformer: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # Vicuna
16
+ llama_model: "/projectnb/ivc-ml/samarth/projects/misc/minigpt-4-chat-models/Llama-2-7b-chat-hf"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "alpro_video_train"
25
+ image_size: 224
26
+ n_frms: 8
27
+ eval:
28
+ name: "alpro_video_eval"
29
+ image_size: 224
30
+ n_frms: 8
31
+ text_processor:
32
+ train:
33
+ name: "blip_caption"
34
+ eval:
35
+ name: "blip_caption"
36
+
global_local/conversation/__init__.py ADDED
File without changes
global_local/conversation/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (225 Bytes). View file
 
global_local/conversation/__pycache__/conversation_video.cpython-39.pyc ADDED
Binary file (12.2 kB). View file
 
global_local/conversation/conversation_video.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt template of Video-LLaMA.
3
+ Adapted from: https://github.com/Vision-CAIR/MiniGPT-4/blob/main/minigpt4/conversation/conversation.py
4
+ """
5
+ import argparse
6
+ import time
7
+ from PIL import Image
8
+ import sys
9
+ import os
10
+ import torch
11
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
12
+ from transformers import StoppingCriteria, StoppingCriteriaList
13
+
14
+ import dataclasses
15
+ from enum import auto, Enum
16
+ from typing import List, Tuple, Any
17
+ import os
18
+ import sys
19
+ from global_local.common.registry import registry
20
+ from global_local.processors.video_processor import ToTHWC,ToUint8,load_video
21
+ from global_local.processors import Blip2ImageEvalProcessor
22
+
23
+ #from video_llama.models.ImageBind.data import load_and_transform_audio_data
24
+ class SeparatorStyle(Enum):
25
+ """Different separator style."""
26
+ SINGLE = auto()
27
+ TWO = auto()
28
+ LLAMA_2 = auto()
29
+
30
+
31
+ @dataclasses.dataclass
32
+ class Conversation:
33
+ """A class that keeps all conversation history."""
34
+ system: str
35
+ roles: List[str]
36
+ messages: List[List[str]]
37
+ offset: int
38
+ # system_img: List[Image.Image] = []
39
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
40
+ sep: str = "###"
41
+ sep2: str = None
42
+
43
+ skip_next: bool = False
44
+ conv_id: Any = None
45
+
46
+ def get_prompt(self):
47
+ if self.sep_style == SeparatorStyle.SINGLE:
48
+ ret = self.system + self.sep
49
+ for role, message in self.messages:
50
+ if message:
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ elif self.sep_style == SeparatorStyle.TWO:
56
+ seps = [self.sep, self.sep2]
57
+ ret = self.system + seps[0]
58
+ for i, (role, message) in enumerate(self.messages):
59
+ if message:
60
+ ret += role + ": " + message + seps[i % 2]
61
+ else:
62
+ ret += role + ":"
63
+ return ret
64
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
65
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
66
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
67
+ ret = ""
68
+
69
+ for i, (role, message) in enumerate(self.messages):
70
+ if i == 0:
71
+ assert message, "first message should not be none"
72
+ assert role == self.roles[0], "first message should come from user"
73
+ if message:
74
+ if type(message) is tuple:
75
+ message, _, _ = message
76
+ if i == 0: message = wrap_sys(self.system) + message
77
+ if i % 2 == 0:
78
+ message = wrap_inst(message)
79
+ ret += self.sep + message
80
+ else:
81
+ ret += " " + message + " " + self.sep2
82
+ else:
83
+ ret += ""
84
+ ret = ret.lstrip(self.sep)
85
+ return ret
86
+ else:
87
+ raise ValueError(f"Invalid style: {self.sep_style}")
88
+
89
+ def append_message(self, role, message):
90
+ self.messages.append([role, message])
91
+
92
+ def to_gradio_chatbot(self):
93
+ ret = []
94
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
95
+ if i % 2 == 0:
96
+ ret.append([msg, None])
97
+ else:
98
+ ret[-1][-1] = msg
99
+ return ret
100
+
101
+ def copy(self):
102
+ return Conversation(
103
+ system=self.system,
104
+ # system_img=self.system_img,
105
+ roles=self.roles,
106
+ messages=[[x, y] for x, y in self.messages],
107
+ offset=self.offset,
108
+ sep_style=self.sep_style,
109
+ sep=self.sep,
110
+ sep2=self.sep2,
111
+ conv_id=self.conv_id)
112
+
113
+ def dict(self):
114
+ return {
115
+ "system": self.system,
116
+ # "system_img": self.system_img,
117
+ "roles": self.roles,
118
+ "messages": self.messages,
119
+ "offset": self.offset,
120
+ "sep": self.sep,
121
+ "sep2": self.sep2,
122
+ "conv_id": self.conv_id,
123
+ }
124
+
125
+
126
+ class StoppingCriteriaSub(StoppingCriteria):
127
+
128
+ def __init__(self, stops=[], encounters=1):
129
+ super().__init__()
130
+ self.stops = stops
131
+
132
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
133
+ for stop in self.stops:
134
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
135
+ return True
136
+
137
+ return False
138
+
139
+
140
+ CONV_VISION = Conversation(
141
+ system="Give the following image: <Img>ImageContent</Img>. "
142
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
143
+ roles=("Human", "Assistant"),
144
+ messages=[],
145
+ offset=0,
146
+ sep_style=SeparatorStyle.SINGLE,
147
+ sep="###",
148
+ )
149
+
150
+ default_conversation = Conversation(
151
+ system="",
152
+ roles=("Human", "Assistant"),
153
+ messages=[],
154
+ offset=0,
155
+ sep_style=SeparatorStyle.SINGLE,
156
+ sep="###",
157
+ )
158
+ conv_llava_llama_2 = Conversation(
159
+ system="You are a helpful language and vision assistant. "
160
+ "You are able to understand the visual content that the user provides, "
161
+ "and assist the user with a variety of tasks using natural language.",
162
+ roles=("USER", "ASSISTANT"),
163
+ messages=(),
164
+ offset=0,
165
+ sep_style=SeparatorStyle.LLAMA_2,
166
+ sep="<s>",
167
+ sep2="</s>",
168
+ )
169
+ class Chat:
170
+ def __init__(self, model, vis_processor, device='cuda:0'):
171
+ self.device = device
172
+ self.model = model
173
+ self.vis_processor = vis_processor
174
+ self.image_vis_processor = Blip2ImageEvalProcessor()
175
+ # stop_words_ids = [torch.tensor([835]).to(self.device),
176
+ # torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
177
+ # self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
178
+
179
+ self.num_frames_per_clip = 16
180
+ self.num_segments = 4
181
+
182
+ def ask(self, text, conv):
183
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
184
+ and ('</Video>' in conv.messages[-1][1] or '</Image>' in conv.messages[-1][1]): # last message is image.
185
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
186
+ else:
187
+ conv.append_message(conv.roles[0], text)
188
+
189
+ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
190
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
191
+ conv.append_message(conv.roles[1], None)
192
+ embs = self.get_context_emb(conv, img_list)
193
+
194
+ current_max_len = embs.shape[1] + max_new_tokens
195
+ if current_max_len - max_length > 0:
196
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
197
+ 'The model will not see the contexts outside the range.')
198
+ begin_idx = max(0, current_max_len - max_length)
199
+
200
+ embs = embs[:, begin_idx:]
201
+ if conv.sep =="###":
202
+ stop_words_ids = [torch.tensor([835]).to(self.device),
203
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
204
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
205
+ else:
206
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
207
+ stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
208
+
209
+ # stopping_criteria
210
+ outputs = self.model.llama_model.generate(
211
+ inputs_embeds=embs,
212
+ max_new_tokens=max_new_tokens,
213
+ stopping_criteria=stopping_criteria,
214
+ num_beams=num_beams,
215
+ do_sample=True,
216
+ min_length=min_length,
217
+ top_p=top_p,
218
+ repetition_penalty=repetition_penalty,
219
+ length_penalty=length_penalty,
220
+ temperature=temperature,
221
+ )
222
+ output_token = outputs[0]
223
+ if output_token[0] == 0: # the model might output a unknow token <unk> at the beginning. remove it
224
+ output_token = output_token[1:]
225
+ if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
226
+ output_token = output_token[1:]
227
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
228
+ if conv.sep =="###":
229
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
230
+ output_text = output_text.split('Assistant:')[-1].strip()
231
+ else:
232
+ output_text = output_text.split(conv.sep2)[0] # remove the stop sign '###'
233
+ output_text = output_text.split(conv.roles[1]+':')[-1].strip()
234
+ conv.messages[-1][1] = output_text
235
+ return output_text, output_token.cpu().numpy()
236
+
237
+ def upload_video(self, video_path, conv, img_list):
238
+
239
+ msg = ""
240
+ if isinstance(video_path, str): # is a video path
241
+ ext = os.path.splitext(video_path)[-1].lower()
242
+ print(video_path)
243
+ # image = self.vis_processor(image).unsqueeze(0).to(self.device)
244
+ video, msg = load_video(
245
+ video_path=video_path,
246
+ n_frms=8,
247
+ height=224,
248
+ width=224,
249
+ sampling ="uniform", return_msg = True
250
+ )
251
+ video = self.vis_processor.transform(video)
252
+ video = video.unsqueeze(0).to(self.device)
253
+ # print(image)
254
+ else:
255
+ raise NotImplementedError
256
+
257
+ try:
258
+ audio_flag = 1
259
+ audio = load_and_transform_audio_data([video_path],"cpu", clips_per_video=8)
260
+ audio = audio.to(self.device)
261
+ except :
262
+ print('no audio is found')
263
+ audio_flag = 0
264
+ finally:
265
+ if audio_flag == 1:
266
+ # image_emb, _ = self.model.encode_videoQformer_audiovideo(video,audio)
267
+ image_emb, _ = self.model.encode_videoQformer_visual(video)
268
+ audio_emb,_ = self.model.encode_audioQformer(audio)
269
+ img_list.append(audio_emb)
270
+ img_list.append(image_emb)
271
+ conv.system = ""
272
+ # conv.append_message(conv.roles[0], "The audio of this video is <Video><ImageHere></Video> ")
273
+ conv.append_message(conv.roles[0], "Close your eyes, open your ears and you imagine only based on the sound that: <ImageHere>. \
274
+ Close your ears, open your eyes and you see that <Video><ImageHere></Video>. \
275
+ Now answer my question based on what you have just seen and heard.")
276
+
277
+ else: # only vison no audio
278
+ # conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
279
+ image_emb, _ = self.model.encode_videoQformer_visual(video)
280
+ img_list.append(image_emb)
281
+ conv.append_message(conv.roles[0], "<Video><ImageHere></Video> "+ msg)
282
+ return "Received."
283
+
284
+ def upload_video_without_audio(self, video_path, conv, img_list):
285
+ msg = ""
286
+ if isinstance(video_path, str): # is a video path
287
+ ext = os.path.splitext(video_path)[-1].lower()
288
+ print(video_path)
289
+ # image = self.vis_processor(image).unsqueeze(0).to(self.device)
290
+ video, msg = load_video(
291
+ video_path=video_path,
292
+ n_frms=self.num_frames_per_clip*self.num_segments,
293
+ height=224,
294
+ width=224,
295
+ sampling ="uniform", return_msg = True
296
+ )
297
+
298
+ video = self.vis_processor.transform(video)
299
+ video = video.unsqueeze(0).to(self.device)
300
+ else:
301
+ raise NotImplementedError
302
+
303
+ # conv.system = "You can understand the video that the user provides. Follow the instructions carefully and explain your answers in detail."
304
+ #image_emb, _ = self.model.encode_videoQformer_visual(video)
305
+ image_emb, _ = self.process_video_frames(video)
306
+ img_list.append(image_emb)
307
+ conv.append_message(conv.roles[0], "<Video><ImageHere></Video> "+ msg)
308
+
309
+ return "Received."
310
+
311
+ def process_video_frames(self, all_frames):
312
+ total_num_frames = self.num_frames_per_clip * self.num_segments
313
+ global_clip_indices = torch.linspace(0, total_num_frames-1, steps=self.num_frames_per_clip)
314
+ short_window_indices = torch.linspace(0, total_num_frames-1, steps=self.num_frames_per_clip * self.num_segments)
315
+
316
+ global_processed_frames = []
317
+ for i in global_clip_indices:
318
+ i = int(i)
319
+ curr = all_frames[:, :, i]
320
+ #curr = np.uint8(all_frames[i])
321
+ #curr = frame_transform(Image.fromarray(curr))
322
+ global_processed_frames.append(curr)
323
+ global_processed_frames = torch.stack(global_processed_frames, dim=2)
324
+
325
+ '''if len(global_processed_frames) < args.num_frames_per_clip:
326
+ diff = args.num_frames_per_clip - len(global_processed_frames)
327
+ pad = global_processed_frames[-1].unsqueeze(0).repeat(diff, 1, 1, 1)
328
+ global_processed_frames = torch.cat((global_processed_frames, pad), dim=0)'''
329
+
330
+ short_window_processed_frames = []
331
+ for i in short_window_indices:
332
+ i = int(i)
333
+ curr = all_frames[:, :, i]
334
+ #curr = np.uint8(all_frames[i])
335
+ #curr = frame_transform(Image.fromarray(curr))
336
+ short_window_processed_frames.append(curr)
337
+ short_window_processed_frames = torch.stack(short_window_processed_frames, dim=2)
338
+
339
+ '''if len(short_window_processed_frames) < args.num_frames_per_clip * args.num_segments:
340
+ diff = args.num_frames_per_clip * args.num_segments - len(short_window_processed_frames)
341
+ pad = short_window_processed_frames[-1].unsqueeze(0).repeat(diff, 1, 1, 1)
342
+ short_window_processed_frames = torch.cat((short_window_processed_frames, pad), dim=0)'''
343
+
344
+ global_attn_mask = torch.zeros((self.num_frames_per_clip))
345
+ global_attn_mask[:global_processed_frames.size(2)] = True
346
+
347
+ short_window_attn_mask = torch.zeros((self.num_frames_per_clip * self.num_segments))
348
+ short_window_attn_mask[:short_window_processed_frames.size(2)] = True
349
+
350
+ global_processed_frames = global_processed_frames.permute((0, 2, 1, 3, 4)).cuda()
351
+ short_window_processed_frames = short_window_processed_frames.permute((0, 2, 1, 3, 4)).cuda()
352
+ global_frame_attn_mask = global_attn_mask.unsqueeze(0).cuda()
353
+ segments_frame_attn_mask = short_window_attn_mask.unsqueeze(0).cuda()
354
+
355
+ with torch.no_grad():
356
+ samples = {'global_video': global_processed_frames, 'global_frame_attn_mask': global_frame_attn_mask, 'segments_video': short_window_processed_frames, 'segments_frame_attn_mask': segments_frame_attn_mask}
357
+ merged_video_embeds, merged_video_embeds_mask = self.model.compute_merged_video_embeds(samples)
358
+
359
+ return merged_video_embeds, merged_video_embeds_mask
360
+
361
+ def upload_img(self, image, conv, img_list):
362
+
363
+ msg = ""
364
+ if isinstance(image, str): # is a image path
365
+ raw_image = Image.open(image).convert('RGB') # 增加一个时间维度
366
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
367
+ elif isinstance(image, Image.Image):
368
+ raw_image = image
369
+ image = self.image_vis_processor(raw_image).unsqueeze(0).unsqueeze(2).to(self.device)
370
+ elif isinstance(image, torch.Tensor):
371
+ if len(image.shape) == 3:
372
+ image = image.unsqueeze(0)
373
+ image = image.to(self.device)
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ image_emb, _ = self.model.encode_videoQformer_visual(image)
378
+ img_list.append(image_emb)
379
+ # Todo msg=""
380
+ conv.append_message(conv.roles[0], "<Image><ImageHere></Image> "+ msg)
381
+
382
+ return "Received."
383
+
384
+ def get_context_emb(self, conv, img_list):
385
+ prompt = conv.get_prompt()
386
+ prompt_segs = prompt.split('<ImageHere>')
387
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
388
+ seg_tokens = [
389
+ self.model.llama_tokenizer(
390
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
391
+ # only add bos to the first seg
392
+ for i, seg in enumerate(prompt_segs)
393
+ ]
394
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
395
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
396
+ mixed_embs = torch.cat(mixed_embs, dim=1)
397
+
398
+ return mixed_embs
399
+
400
+ if __name__ =='__main__':
401
+ video_path = '/mnt/workspace/videoGPT/Video-LLaMA/examples/applausing.mp4'
402
+ # import torch.classes.torchaudio.ffmpeg_StreamReader
403
+ # ffmpeg_StreamReader(video_path)
404
+ load_and_transform_audio_data([video_path],"cpu", clips_per_video=8)
global_local/datasets/__init__.py ADDED
File without changes
global_local/datasets/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (221 Bytes). View file
 
global_local/datasets/__pycache__/data_utils.cpython-39.pyc ADDED
Binary file (6.04 kB). View file
 
global_local/datasets/builders/__init__.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from global_local.datasets.builders.base_dataset_builder import load_dataset_config
9
+ from global_local.datasets.builders.image_text_pair_builder import (
10
+ CCSBUBuilder,
11
+ LaionBuilder,
12
+ CCSBUAlignBuilder
13
+ )
14
+ from global_local.datasets.builders.video_caption_builder import WebvidBuilder
15
+ from global_local.common.registry import registry
16
+ from global_local.datasets.builders.instruct_builder import WebvidInstruct_Builder,LlavaInstruct_Builder
17
+ __all__ = [
18
+ "CCSBUBuilder",
19
+ "LaionBuilder",
20
+ "CCSBUAlignBuilder",
21
+ "WebvidBuilder",
22
+ "LlavaInstruct_Builder",
23
+ "WebvidInstruct_Builder"
24
+
25
+ ]
26
+
27
+
28
+ def load_dataset(name, cfg_path=None, vis_path=None, data_type=None):
29
+ """
30
+ Example
31
+
32
+ >>> dataset = load_dataset("coco_caption", cfg=None)
33
+ >>> splits = dataset.keys()
34
+ >>> print([len(dataset[split]) for split in splits])
35
+
36
+ """
37
+ if cfg_path is None:
38
+ cfg = None
39
+ else:
40
+ cfg = load_dataset_config(cfg_path)
41
+
42
+ try:
43
+ builder = registry.get_builder_class(name)(cfg)
44
+ except TypeError:
45
+ print(
46
+ f"Dataset {name} not found. Available datasets:\n"
47
+ + ", ".join([str(k) for k in dataset_zoo.get_names()])
48
+ )
49
+ exit(1)
50
+
51
+ if vis_path is not None:
52
+ if data_type is None:
53
+ # use default data type in the config
54
+ data_type = builder.config.data_type
55
+
56
+ assert (
57
+ data_type in builder.config.build_info
58
+ ), f"Invalid data_type {data_type} for {name}."
59
+
60
+ builder.config.build_info.get(data_type).storage = vis_path
61
+
62
+ dataset = builder.build_datasets()
63
+ return dataset
64
+
65
+
66
+ class DatasetZoo:
67
+ def __init__(self) -> None:
68
+ self.dataset_zoo = {
69
+ k: list(v.DATASET_CONFIG_DICT.keys())
70
+ for k, v in sorted(registry.mapping["builder_name_mapping"].items())
71
+ }
72
+
73
+ def get_names(self):
74
+ return list(self.dataset_zoo.keys())
75
+
76
+
77
+ dataset_zoo = DatasetZoo()
global_local/datasets/builders/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (2.69 kB). View file
 
global_local/datasets/builders/__pycache__/base_dataset_builder.cpython-39.pyc ADDED
Binary file (6.15 kB). View file
 
global_local/datasets/builders/__pycache__/image_text_pair_builder.cpython-39.pyc ADDED
Binary file (3.06 kB). View file
 
global_local/datasets/builders/__pycache__/instruct_builder.cpython-39.pyc ADDED
Binary file (2.62 kB). View file
 
global_local/datasets/builders/__pycache__/video_caption_builder.cpython-39.pyc ADDED
Binary file (1.52 kB). View file
 
global_local/datasets/builders/base_dataset_builder.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file is from
3
+ Copyright (c) 2022, salesforce.com, inc.
4
+ All rights reserved.
5
+ SPDX-License-Identifier: BSD-3-Clause
6
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
7
+ """
8
+
9
+ import logging
10
+ import os
11
+ import shutil
12
+ import warnings
13
+
14
+ from omegaconf import OmegaConf
15
+ import torch.distributed as dist
16
+ from torchvision.datasets.utils import download_url
17
+
18
+ import global_local.common.utils as utils
19
+ from global_local.common.dist_utils import is_dist_avail_and_initialized, is_main_process
20
+ from global_local.common.registry import registry
21
+ from global_local.processors.base_processor import BaseProcessor
22
+
23
+
24
+
25
+ class BaseDatasetBuilder:
26
+ train_dataset_cls, eval_dataset_cls = None, None
27
+
28
+ def __init__(self, cfg=None):
29
+ super().__init__()
30
+
31
+ if cfg is None:
32
+ # help to create datasets from default config.
33
+ self.config = load_dataset_config(self.default_config_path())
34
+ elif isinstance(cfg, str):
35
+ self.config = load_dataset_config(cfg)
36
+ else:
37
+ # when called from task.build_dataset()
38
+ self.config = cfg
39
+
40
+ self.data_type = self.config.data_type
41
+
42
+ self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
43
+ self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
44
+
45
+ def build_datasets(self):
46
+ # download, split, etc...
47
+ # only called on 1 GPU/TPU in distributed
48
+
49
+ if is_main_process():
50
+ self._download_data()
51
+
52
+ if is_dist_avail_and_initialized():
53
+ dist.barrier()
54
+
55
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
56
+ logging.info("Building datasets...")
57
+ datasets = self.build() # dataset['train'/'val'/'test']
58
+
59
+ return datasets
60
+
61
+ def build_processors(self):
62
+ vis_proc_cfg = self.config.get("vis_processor")
63
+ txt_proc_cfg = self.config.get("text_processor")
64
+
65
+ if vis_proc_cfg is not None:
66
+ vis_train_cfg = vis_proc_cfg.get("train")
67
+ vis_eval_cfg = vis_proc_cfg.get("eval")
68
+
69
+ self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
70
+ self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)
71
+
72
+ if txt_proc_cfg is not None:
73
+ txt_train_cfg = txt_proc_cfg.get("train")
74
+ txt_eval_cfg = txt_proc_cfg.get("eval")
75
+
76
+ self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
77
+ self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)
78
+
79
+ @staticmethod
80
+ def _build_proc_from_cfg(cfg):
81
+ return (
82
+ registry.get_processor_class(cfg.name).from_config(cfg)
83
+ if cfg is not None
84
+ else None
85
+ )
86
+
87
+ @classmethod
88
+ def default_config_path(cls, type="default"):
89
+ return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])
90
+
91
+ def _download_data(self):
92
+ self._download_ann()
93
+ self._download_vis()
94
+
95
+ def _download_ann(self):
96
+ """
97
+ Download annotation files if necessary.
98
+ All the vision-language datasets should have annotations of unified format.
99
+
100
+ storage_path can be:
101
+ (1) relative/absolute: will be prefixed with env.cache_root to make full path if relative.
102
+ (2) basename/dirname: will be suffixed with base name of URL if dirname is provided.
103
+
104
+ Local annotation paths should be relative.
105
+ """
106
+ anns = self.config.build_info.annotations
107
+
108
+ splits = anns.keys()
109
+
110
+ cache_root = registry.get_path("cache_root")
111
+
112
+ for split in splits:
113
+ info = anns[split]
114
+
115
+ urls, storage_paths = info.get("url", None), info.storage
116
+
117
+ if isinstance(urls, str):
118
+ urls = [urls]
119
+ if isinstance(storage_paths, str):
120
+ storage_paths = [storage_paths]
121
+
122
+ assert len(urls) == len(storage_paths)
123
+
124
+ for url_or_filename, storage_path in zip(urls, storage_paths):
125
+ # if storage_path is relative, make it full by prefixing with cache_root.
126
+ if not os.path.isabs(storage_path):
127
+ storage_path = os.path.join(cache_root, storage_path)
128
+
129
+ dirname = os.path.dirname(storage_path)
130
+ if not os.path.exists(dirname):
131
+ os.makedirs(dirname)
132
+
133
+ if os.path.isfile(url_or_filename):
134
+ src, dst = url_or_filename, storage_path
135
+ if not os.path.exists(dst):
136
+ shutil.copyfile(src=src, dst=dst)
137
+ else:
138
+ logging.info("Using existing file {}.".format(dst))
139
+ else:
140
+ if os.path.isdir(storage_path):
141
+ # if only dirname is provided, suffix with basename of URL.
142
+ raise ValueError(
143
+ "Expecting storage_path to be a file path, got directory {}".format(
144
+ storage_path
145
+ )
146
+ )
147
+ else:
148
+ filename = os.path.basename(storage_path)
149
+
150
+ download_url(url=url_or_filename, root=dirname, filename=filename)
151
+
152
+ def _download_vis(self):
153
+
154
+ storage_path = self.config.build_info.get(self.data_type).storage
155
+ storage_path = utils.get_cache_path(storage_path)
156
+
157
+ if not os.path.exists(storage_path):
158
+ warnings.warn(
159
+ f"""
160
+ The specified path {storage_path} for visual inputs does not exist.
161
+ Please provide a correct path to the visual inputs or
162
+ refer to datasets/download_scripts/README.md for downloading instructions.
163
+ """
164
+ )
165
+
166
+ def build(self):
167
+ """
168
+ Create by split datasets inheriting torch.utils.data.Datasets.
169
+
170
+ # build() can be dataset-specific. Overwrite to customize.
171
+ """
172
+ self.build_processors()
173
+
174
+ build_info = self.config.build_info
175
+
176
+ ann_info = build_info.annotations
177
+ vis_info = build_info.get(self.data_type)
178
+
179
+ datasets = dict()
180
+ for split in ann_info.keys():
181
+ if split not in ["train", "val", "test"]:
182
+ continue
183
+
184
+ is_train = split == "train"
185
+
186
+ # processors
187
+ vis_processor = (
188
+ self.vis_processors["train"]
189
+ if is_train
190
+ else self.vis_processors["eval"]
191
+ )
192
+ text_processor = (
193
+ self.text_processors["train"]
194
+ if is_train
195
+ else self.text_processors["eval"]
196
+ )
197
+
198
+ # annotation path
199
+ ann_paths = ann_info.get(split).storage
200
+ if isinstance(ann_paths, str):
201
+ ann_paths = [ann_paths]
202
+
203
+ abs_ann_paths = []
204
+ for ann_path in ann_paths:
205
+ if not os.path.isabs(ann_path):
206
+ ann_path = utils.get_cache_path(ann_path)
207
+ abs_ann_paths.append(ann_path)
208
+ ann_paths = abs_ann_paths
209
+
210
+ # visual data storage path
211
+ vis_path = os.path.join(vis_info.storage, split)
212
+
213
+ if not os.path.isabs(vis_path):
214
+ # vis_path = os.path.join(utils.get_cache_path(), vis_path)
215
+ vis_path = utils.get_cache_path(vis_path)
216
+
217
+ if not os.path.exists(vis_path):
218
+ warnings.warn("storage path {} does not exist.".format(vis_path))
219
+
220
+ # create datasets
221
+ dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
222
+ datasets[split] = dataset_cls(
223
+ vis_processor=vis_processor,
224
+ text_processor=text_processor,
225
+ ann_paths=ann_paths,
226
+ vis_root=vis_path,
227
+ )
228
+
229
+ return datasets
230
+
231
+
232
+ def load_dataset_config(cfg_path):
233
+ cfg = OmegaConf.load(cfg_path).datasets
234
+ cfg = cfg[list(cfg.keys())[0]]
235
+
236
+ return cfg
global_local/datasets/builders/image_text_pair_builder.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from global_local.common.registry import registry
6
+ from global_local.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ from global_local.datasets.datasets.laion_dataset import LaionDataset
8
+ from global_local.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset
9
+
10
+
11
+ @registry.register_builder("cc_sbu")
12
+ class CCSBUBuilder(BaseDatasetBuilder):
13
+ train_dataset_cls = CCSBUDataset
14
+
15
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}
16
+
17
+ def _download_ann(self):
18
+ pass
19
+
20
+ def _download_vis(self):
21
+ pass
22
+
23
+ def build(self):
24
+ self.build_processors()
25
+
26
+ build_info = self.config.build_info
27
+
28
+ datasets = dict()
29
+ split = "train"
30
+
31
+ # create datasets
32
+ # [NOTE] return inner_datasets (wds.DataPipeline)
33
+ dataset_cls = self.train_dataset_cls
34
+ datasets[split] = dataset_cls(
35
+ vis_processor=self.vis_processors[split],
36
+ text_processor=self.text_processors[split],
37
+ location=build_info.storage,
38
+ ).inner_dataset
39
+
40
+ return datasets
41
+
42
+
43
+ @registry.register_builder("laion")
44
+ class LaionBuilder(BaseDatasetBuilder):
45
+ train_dataset_cls = LaionDataset
46
+
47
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}
48
+
49
+ def _download_ann(self):
50
+ pass
51
+
52
+ def _download_vis(self):
53
+ pass
54
+
55
+ def build(self):
56
+ self.build_processors()
57
+
58
+ build_info = self.config.build_info
59
+
60
+ datasets = dict()
61
+ split = "train"
62
+
63
+ # create datasets
64
+ # [NOTE] return inner_datasets (wds.DataPipeline)
65
+ dataset_cls = self.train_dataset_cls
66
+ datasets[split] = dataset_cls(
67
+ vis_processor=self.vis_processors[split],
68
+ text_processor=self.text_processors[split],
69
+ location=build_info.storage,
70
+ ).inner_dataset
71
+
72
+ return datasets
73
+
74
+
75
+ @registry.register_builder("cc_sbu_align")
76
+ class CCSBUAlignBuilder(BaseDatasetBuilder):
77
+ train_dataset_cls = CCSBUAlignDataset
78
+
79
+ DATASET_CONFIG_DICT = {
80
+ "default": "configs/datasets/cc_sbu/align.yaml",
81
+ }
82
+
83
+ def build_datasets(self):
84
+ # at this point, all the annotations and image/videos should be all downloaded to the specified locations.
85
+ logging.info("Building datasets...")
86
+ self.build_processors()
87
+
88
+ build_info = self.config.build_info
89
+ storage_path = build_info.storage
90
+
91
+ datasets = dict()
92
+
93
+ if not os.path.exists(storage_path):
94
+ warnings.warn("storage path {} does not exist.".format(storage_path))
95
+
96
+ # create datasets
97
+ dataset_cls = self.train_dataset_cls
98
+ datasets['train'] = dataset_cls(
99
+ vis_processor=self.vis_processors["train"],
100
+ text_processor=self.text_processors["train"],
101
+ ann_paths=[os.path.join(storage_path, 'filter_cap.json')],
102
+ vis_root=os.path.join(storage_path, 'image'),
103
+ )
104
+
105
+ return datasets
106
+
global_local/datasets/builders/instruct_builder.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from global_local.common.registry import registry
6
+ from global_local.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ from global_local.datasets.datasets.laion_dataset import LaionDataset
8
+ from global_local.datasets.datasets.llava_instruct_dataset import Instruct_Dataset
9
+ from global_local.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset
10
+
11
+ @registry.register_builder("instruct")
12
+ class Instruct_Builder(BaseDatasetBuilder):
13
+ train_dataset_cls = Instruct_Dataset
14
+
15
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/instruct/defaults.yaml"}
16
+
17
+ def _download_ann(self):
18
+ pass
19
+
20
+ def _download_vis(self):
21
+ pass
22
+
23
+ def build(self):
24
+ self.build_processors()
25
+ datasets = dict()
26
+ split = "train"
27
+
28
+ build_info = self.config.build_info
29
+ dataset_cls = self.train_dataset_cls
30
+ if self.config.num_video_query_token:
31
+ num_video_query_token = self.config.num_video_query_token
32
+ else:
33
+ num_video_query_token = 32
34
+
35
+ if self.config.tokenizer_name:
36
+ tokenizer_name = self.config.tokenizer_name
37
+ else:
38
+ tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/'
39
+
40
+
41
+ datasets[split] = dataset_cls(
42
+ vis_processor=self.vis_processors[split],
43
+ text_processor=self.text_processors[split],
44
+ vis_root=build_info.videos_dir,
45
+ ann_root=build_info.anno_dir,
46
+ num_video_query_token = num_video_query_token,
47
+ tokenizer_name = tokenizer_name,
48
+ data_type = self.config.data_type
49
+ )
50
+
51
+ return datasets
52
+
53
+ @registry.register_builder("webvid_instruct")
54
+ class WebvidInstruct_Builder(Instruct_Builder):
55
+ train_dataset_cls = Video_Instruct_Dataset
56
+
57
+ DATASET_CONFIG_DICT = {
58
+ "default": "configs/datasets/instruct/webvid_instruct.yaml",
59
+ }
60
+
61
+ @registry.register_builder("webvid_instruct_zh")
62
+ class WebvidInstruct_zh_Builder(Instruct_Builder):
63
+ train_dataset_cls = Video_Instruct_Dataset
64
+
65
+ DATASET_CONFIG_DICT = {
66
+ "default": "configs/datasets/instruct/webvid_instruct.yaml",
67
+ }
68
+
69
+
70
+
71
+ @registry.register_builder("llava_instruct")
72
+ class LlavaInstruct_Builder(Instruct_Builder):
73
+ train_dataset_cls = Instruct_Dataset
74
+
75
+ DATASET_CONFIG_DICT = {
76
+ "default": "configs/datasets/instruct/llava_instruct.yaml",
77
+ }
78
+
global_local/datasets/builders/video_caption_builder.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import warnings
4
+
5
+ from global_local.common.registry import registry
6
+ from global_local.datasets.builders.base_dataset_builder import BaseDatasetBuilder
7
+ from global_local.datasets.datasets.webvid_datasets import WebvidDataset
8
+
9
+ @registry.register_builder("webvid")
10
+ class WebvidBuilder(BaseDatasetBuilder):
11
+ train_dataset_cls = WebvidDataset
12
+ DATASET_CONFIG_DICT = {"default": "configs/datasets/webvid/defaults.yaml"}
13
+
14
+ def _download_ann(self):
15
+ pass
16
+
17
+ def _download_vis(self):
18
+ pass
19
+
20
+ def build(self):
21
+ self.build_processors()
22
+ datasets = dict()
23
+ split = "train"
24
+
25
+ build_info = self.config.build_info
26
+ dataset_cls = self.train_dataset_cls
27
+ datasets[split] = dataset_cls(
28
+ vis_processor=self.vis_processors[split],
29
+ text_processor=self.text_processors[split],
30
+ vis_root=build_info.videos_dir,
31
+ ann_root=build_info.anno_dir
32
+ )
33
+
34
+ return datasets
global_local/datasets/data_utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import gzip
9
+ import logging
10
+ import os
11
+ import random as rnd
12
+ import tarfile
13
+ import zipfile
14
+ import random
15
+ from typing import List
16
+ from tqdm import tqdm
17
+
18
+ import decord
19
+ from decord import VideoReader
20
+ import webdataset as wds
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils.data.dataset import IterableDataset
24
+
25
+ from global_local.common.registry import registry
26
+ from global_local.datasets.datasets.base_dataset import ConcatDataset
27
+
28
+
29
+ decord.bridge.set_bridge("torch")
30
+ MAX_INT = registry.get("MAX_INT")
31
+
32
+
33
+ class ChainDataset(wds.DataPipeline):
34
+ r"""Dataset for chaining multiple :class:`DataPipeline` s.
35
+
36
+ This class is useful to assemble different existing dataset streams. The
37
+ chaining operation is done on-the-fly, so concatenating large-scale
38
+ datasets with this class will be efficient.
39
+
40
+ Args:
41
+ datasets (iterable of IterableDataset): datasets to be chained together
42
+ """
43
+ def __init__(self, datasets: List[wds.DataPipeline]) -> None:
44
+ super().__init__()
45
+ self.datasets = datasets
46
+ self.prob = []
47
+ self.names = []
48
+ for dataset in self.datasets:
49
+ if hasattr(dataset, 'name'):
50
+ self.names.append(dataset.name)
51
+ else:
52
+ self.names.append('Unknown')
53
+ if hasattr(dataset, 'sample_ratio'):
54
+ self.prob.append(dataset.sample_ratio)
55
+ else:
56
+ self.prob.append(1)
57
+ logging.info("One of the datapipeline doesn't define ratio and set to 1 automatically.")
58
+
59
+ def __iter__(self):
60
+ datastreams = [iter(dataset) for dataset in self.datasets]
61
+ while True:
62
+ select_datastream = random.choices(datastreams, weights=self.prob, k=1)[0]
63
+ yield next(select_datastream)
64
+
65
+
66
+ def apply_to_sample(f, sample):
67
+ if len(sample) == 0:
68
+ return {}
69
+
70
+ def _apply(x):
71
+ if torch.is_tensor(x):
72
+ return f(x)
73
+ elif isinstance(x, dict):
74
+ return {key: _apply(value) for key, value in x.items()}
75
+ elif isinstance(x, list):
76
+ return [_apply(x) for x in x]
77
+ else:
78
+ return x
79
+
80
+ return _apply(sample)
81
+
82
+
83
+ def move_to_cuda(sample):
84
+ def _move_to_cuda(tensor):
85
+ return tensor.cuda()
86
+
87
+ return apply_to_sample(_move_to_cuda, sample)
88
+
89
+
90
+ def prepare_sample(samples, cuda_enabled=True):
91
+ if cuda_enabled:
92
+ samples = move_to_cuda(samples)
93
+
94
+ # TODO fp16 support
95
+
96
+ return samples
97
+
98
+
99
+ def reorg_datasets_by_split(datasets):
100
+ """
101
+ Organizes datasets by split.
102
+
103
+ Args:
104
+ datasets: dict of torch.utils.data.Dataset objects by name.
105
+
106
+ Returns:
107
+ Dict of datasets by split {split_name: List[Datasets]}.
108
+ """
109
+ # if len(datasets) == 1:
110
+ # return datasets[list(datasets.keys())[0]]
111
+ # else:
112
+ reorg_datasets = dict()
113
+
114
+ # reorganize by split
115
+ for _, dataset in datasets.items():
116
+ for split_name, dataset_split in dataset.items():
117
+ if split_name not in reorg_datasets:
118
+ reorg_datasets[split_name] = [dataset_split]
119
+ else:
120
+ reorg_datasets[split_name].append(dataset_split)
121
+
122
+ return reorg_datasets
123
+
124
+
125
+ def concat_datasets(datasets):
126
+ """
127
+ Concatenates multiple datasets into a single dataset.
128
+
129
+ It supports may-style datasets and DataPipeline from WebDataset. Currently, does not support
130
+ generic IterableDataset because it requires creating separate samplers.
131
+
132
+ Now only supports conctenating training datasets and assuming validation and testing
133
+ have only a single dataset. This is because metrics should not be computed on the concatenated
134
+ datasets.
135
+
136
+ Args:
137
+ datasets: dict of torch.utils.data.Dataset objects by split.
138
+
139
+ Returns:
140
+ Dict of concatenated datasets by split, "train" is the concatenation of multiple datasets,
141
+ "val" and "test" remain the same.
142
+
143
+ If the input training datasets contain both map-style and DataPipeline datasets, returns
144
+ a tuple, where the first element is a concatenated map-style dataset and the second
145
+ element is a chained DataPipeline dataset.
146
+
147
+ """
148
+ # concatenate datasets in the same split
149
+ for split_name in datasets:
150
+ if split_name != "train":
151
+ assert (
152
+ len(datasets[split_name]) == 1
153
+ ), "Do not support multiple {} datasets.".format(split_name)
154
+ datasets[split_name] = datasets[split_name][0]
155
+ else:
156
+ iterable_datasets, map_datasets = [], []
157
+ for dataset in datasets[split_name]:
158
+ if isinstance(dataset, wds.DataPipeline):
159
+ logging.info(
160
+ "Dataset {} is IterableDataset, can't be concatenated.".format(
161
+ dataset
162
+ )
163
+ )
164
+ iterable_datasets.append(dataset)
165
+ elif isinstance(dataset, IterableDataset):
166
+ raise NotImplementedError(
167
+ "Do not support concatenation of generic IterableDataset."
168
+ )
169
+ else:
170
+ map_datasets.append(dataset)
171
+
172
+ # if len(iterable_datasets) > 0:
173
+ # concatenate map-style datasets and iterable-style datasets separately
174
+ if len(iterable_datasets) > 1:
175
+ chained_datasets = (
176
+ ChainDataset(iterable_datasets)
177
+ )
178
+ elif len(iterable_datasets) == 1:
179
+ chained_datasets = iterable_datasets[0]
180
+ else:
181
+ chained_datasets = None
182
+
183
+ concat_datasets = (
184
+ ConcatDataset(map_datasets) if len(map_datasets) > 0 else None
185
+ )
186
+
187
+ train_datasets = concat_datasets, chained_datasets
188
+ train_datasets = tuple([x for x in train_datasets if x is not None])
189
+ train_datasets = (
190
+ train_datasets[0] if len(train_datasets) == 1 else train_datasets
191
+ )
192
+
193
+ datasets[split_name] = train_datasets
194
+
195
+ return datasets
196
+
global_local/datasets/datasets/__init__.py ADDED
File without changes