wchai committed on
Commit 1bbd0a2
1 Parent(s): 7cb12e9

Upload 3_8b_v/xtuner_config.py with huggingface_hub

Files changed (1)
  1. 3_8b_v/xtuner_config.py +303 -0
3_8b_v/xtuner_config.py ADDED
@@ -0,0 +1,303 @@
import torch
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)

from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig,
                          CLIPImageProcessor, CLIPVisionModel,
                          SiglipVisionModel, SiglipImageProcessor, AutoProcessor)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR

from peft import LoraConfig
from torch.optim import AdamW
from xtuner.dataset import LLaVADataset, CambrianDataset, ConcatDataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import llava_map_fn, cambrian_map_fn, template_map_fn_factory
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine import DatasetInfoHook, EvaluateChatHook
from xtuner.model import LLaVAModel, PikaModel
from xtuner.utils import PROMPT_TEMPLATE


#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
llm_name_or_path = 'meta-llama/Meta-Llama-3-8B-Instruct'
visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
pretrained_pth = '/data/wenhao/projects/xtuner/work_dirs/final_new_p/projector'

prompt_template = PROMPT_TEMPLATE.llama3_chat
max_length = 4096
size = 378
batch_size = 1  # per_device
accumulative_counts = 32
lr = 4e-5
dataloader_num_workers = 0
max_epochs = 1
optim_type = AdamW
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03
sf = False

# Save
save_steps = 200
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

image_processor = dict(
    type=CLIPImageProcessor.from_pretrained,
    pretrained_model_name_or_path='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
    trust_remote_code=True,
    size=size,
    crop_size=size)

model = dict(
    type=PikaModel,
    sf=sf,
    freeze_llm=True,
    freeze_visual_encoder=False,
    pretrained_pth=pretrained_pth,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16),
    visual_encoder=dict(
        type=SiglipVisionModel.from_pretrained,
        pretrained_model_name_or_path=visual_encoder_name_or_path))

#######################################################################
#                    PART 3  Dataset & Dataloader                     #
#######################################################################
m3it_data_root = '/data/wenhao/projects/xtuner/data/m3it/'
m3it_data_path = m3it_data_root + 'm3it.jsonl'
m3it_image_folder = m3it_data_root
m3it_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/m3it/pre_token_llama31',
    image_folder=m3it_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)


chatterbox_data_root = '/data/wenhao/projects/xtuner/data/ChatterBox/'
chatterbox_data_path = chatterbox_data_root + 'chatterbox_76k.jsonl'
chatterbox_image_folder = chatterbox_data_root
chatterbox_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/ChatterBox/pre_token_llama31',
    image_folder=chatterbox_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)


laion_data_root = '/data/wenhao/projects/xtuner/data/LLaVA-Pretrain/'
laion_data_path = laion_data_root + 'laion_558k.jsonl'
laion_image_folder = laion_data_root
laion_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/LLaVA-Pretrain/pre_token_llama31',
    image_folder=laion_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

face_data_root = '/data/wenhao/projects/xtuner/data/FaceCaption-15M/'
face_data_path = face_data_root + 'FaceCaption-100K.jsonl'
face_image_folder = face_data_root + 'full_data'
face_processed_text_folder = face_data_root + 'pre_token_llama3'
face_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder=face_processed_text_folder,
    image_folder=face_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

cost_data_root = '/data/wenhao/projects/xtuner/data/COST/'
cost_data_path = cost_data_root + 'cost.jsonl'
cost_image_folder = cost_data_root
cost_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/COST/pre_token_llama31',
    # tokenizer=tokenizer,
    # data_path='/data/wenhao/projects/xtuner/data/COST/cost.jsonl',
    image_folder=cost_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

sharept_data_root = '/data/wenhao/projects/xtuner/data/ShareGPT4V/'
sharept_data_path = sharept_data_root + 'sharegpt4v_pt.jsonl'
sharept_image_folder = '/data/wenhao/projects/xtuner/data/'
sharept_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/ShareGPT4V/pre_token_llama31',
    # tokenizer=tokenizer,
    # data_path='/data/wenhao/projects/xtuner/data/ShareGPT4V/sharegpt4v_pt.jsonl',
    image_folder=sharept_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

llavaone_data_root = '/data/wenhao/projects/xtuner/data/onevision/'
llavaone_data_path = '/data/wenhao/projects/xtuner/data/LLaVA-OneVision-Data/llava_onevision.jsonl'
llavaone_image_folder = llavaone_data_root + 'images'
llavaone_dataset = dict(
    type=CambrianDataset,
    offline_processed_text_folder='/data/wenhao/projects/xtuner/data/onevision/pre_token_llama31',
    # tokenizer=tokenizer,
    # data_path='/data/wenhao/projects/xtuner/data/LLaVA-OneVision-Data/llava_onevision.jsonl',
    image_folder=llavaone_image_folder,
    image_processor=image_processor,
    dataset_map_fn=cambrian_map_fn,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    pad_image_to_square=True)

train_dataset = dict(
    type=ConcatDataset,
    datasets=[m3it_dataset, chatterbox_dataset, laion_dataset, face_dataset,
              cost_dataset, sharept_dataset, llavaone_dataset],
)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        T_max=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Evaluate the generation performance during the training
evaluation_freq = 100
SYSTEM = ''
evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']


# Log the dialogue periodically during the training process, optional
custom_hooks = [
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        image_processor=image_processor,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        evaluation_images=evaluation_images,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save a checkpoint every `save_steps` iterations.
    checkpoint=dict(
        type=CheckpointHook,
        by_epoch=False,
        interval=save_steps,
        max_keep_ckpts=save_total_limit),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
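
Reading note: this is an mmengine-style lazy config, where each dict(type=..., ...) entry records a callable plus its keyword arguments and is only instantiated by the runner when training starts. Below is a minimal sketch of that pattern applied to the tokenizer entry above; the build helper is hypothetical and only illustrative, since xtuner/mmengine's real builder also resolves registries, string type names, and nested configs.

import torch  # unused here; shown only to mirror the config's import style
from transformers import AutoTokenizer

llm_name_or_path = 'meta-llama/Meta-Llama-3-8B-Instruct'

# The same lazy entry as in PART 2: a callable under 'type' plus its kwargs.
tokenizer_cfg = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=llm_name_or_path,
    trust_remote_code=True,
    padding_side='right')

def build(cfg):
    """Hypothetical helper: instantiate a lazy dict(type=..., ...) entry."""
    cfg = dict(cfg)            # copy so the original config dict stays untouched
    factory = cfg.pop('type')  # the callable to invoke (class, function, or classmethod)
    return factory(**cfg)      # remaining keys become keyword arguments

tokenizer = build(tokenizer_cfg)  # roughly AutoTokenizer.from_pretrained(llm_name_or_path, ...)

Training with a config like this is typically launched through the xtuner CLI, e.g. "xtuner train 3_8b_v/xtuner_config.py", optionally with a DeepSpeed flag such as "--deepspeed deepspeed_zero2"; the exact launch command is not part of this commit.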