StarCycle commited on
Commit
f265c2a
1 Parent(s): 11247e2

Upload 2 files

Browse files
Files changed (2) hide show
  1. finetune.py +233 -0
  2. pretrain.py +214 -0
finetune.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, SiglipImageProcessor,
10
+ SiglipVisionModel)
11
+
12
+ from xtuner.dataset import LLaVADataset
13
+ from xtuner.dataset.collate_fns import default_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from xtuner.dataset.samplers import LengthGroupedSampler
16
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
17
+ from xtuner.engine.runner import TrainLoop
18
+ from xtuner.model import LLaVAModel
19
+ from xtuner.utils import PROMPT_TEMPLATE
20
+
21
+ #######################################################################
22
+ # PART 1 Settings #
23
+ #######################################################################
24
+ # Model
25
+ llm_name_or_path = 'internlm/internlm2-chat-1_8b'
26
+ visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
27
+ # Specify the pretrained pth
28
+ pretrained_pth = './work_dirs/pretrain/iter_8721.pth'
29
+
30
+ # Data
31
+ data_root = './'
32
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
33
+ image_folder = data_root + 'llava_images'
34
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
35
+ max_length = int(2048 - (336 / 14)**2)
36
+
37
+ # Scheduler & Optimizer
38
+ batch_size = 4 # per_device
39
+ accumulative_counts = 8
40
+ dataloader_num_workers = 4
41
+ prefetch = 5
42
+ max_epochs = 1
43
+ optim_type = AdamW
44
+ lr = 2e-4
45
+ betas = (0.9, 0.999)
46
+ weight_decay = 0
47
+ max_norm = 1 # grad clip
48
+ warmup_ratio = 0.03
49
+
50
+ # Save
51
+ save_steps = 500
52
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
53
+
54
+ # Evaluate the generation performance during the training
55
+ evaluation_freq = 500
56
+ SYSTEM = ''
57
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
58
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
59
+
60
+ #######################################################################
61
+ # PART 2 Model & Tokenizer & Image Processor #
62
+ #######################################################################
63
+ tokenizer = dict(
64
+ type=AutoTokenizer.from_pretrained,
65
+ pretrained_model_name_or_path=llm_name_or_path,
66
+ trust_remote_code=True,
67
+ padding_side='right')
68
+
69
+ image_processor = dict(
70
+ type=SiglipImageProcessor.from_pretrained,
71
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
72
+ trust_remote_code=True)
73
+
74
+ model = dict(
75
+ type=LLaVAModel,
76
+ freeze_llm=True,
77
+ freeze_visual_encoder=True,
78
+ pretrained_pth=pretrained_pth,
79
+ llm=dict(
80
+ type=AutoModelForCausalLM.from_pretrained,
81
+ pretrained_model_name_or_path=llm_name_or_path,
82
+ trust_remote_code=True,
83
+ torch_dtype=torch.float16,
84
+ quantization_config=dict(
85
+ type=BitsAndBytesConfig,
86
+ load_in_4bit=True,
87
+ load_in_8bit=False,
88
+ llm_int8_threshold=6.0,
89
+ llm_int8_has_fp16_weight=False,
90
+ bnb_4bit_compute_dtype=torch.float16,
91
+ bnb_4bit_use_double_quant=True,
92
+ bnb_4bit_quant_type='nf4')),
93
+ llm_lora=dict(
94
+ type=LoraConfig,
95
+ r=512,
96
+ lora_alpha=256,
97
+ lora_dropout=0.05,
98
+ bias='none',
99
+ task_type='CAUSAL_LM'),
100
+ visual_encoder=dict(
101
+ type=SiglipVisionModel.from_pretrained,
102
+ pretrained_model_name_or_path=visual_encoder_name_or_path),
103
+ visual_encoder_lora=dict(
104
+ type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'),
105
+ )
106
+
107
+ #######################################################################
108
+ # PART 3 Dataset & Dataloader #
109
+ #######################################################################
110
+ llava_dataset = dict(
111
+ type=LLaVADataset,
112
+ data_path=data_path,
113
+ image_folder=image_folder,
114
+ tokenizer=tokenizer,
115
+ image_processor=image_processor,
116
+ dataset_map_fn=llava_map_fn,
117
+ template_map_fn=dict(
118
+ type=template_map_fn_factory, template=prompt_template),
119
+ max_length=max_length,
120
+ pad_image_to_square=True)
121
+
122
+ train_dataloader = dict(
123
+ batch_size=batch_size,
124
+ num_workers=dataloader_num_workers,
125
+ prefetch_factor=prefetch,
126
+ dataset=llava_dataset,
127
+ sampler=dict(
128
+ type=LengthGroupedSampler,
129
+ length_property='modality_length',
130
+ per_device_batch_size=batch_size * accumulative_counts),
131
+ collate_fn=dict(type=default_collate_fn))
132
+
133
+ #######################################################################
134
+ # PART 4 Scheduler & Optimizer #
135
+ #######################################################################
136
+ # optimizer
137
+ optim_wrapper = dict(
138
+ type=AmpOptimWrapper,
139
+ optimizer=dict(
140
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
141
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
142
+ accumulative_counts=accumulative_counts,
143
+ loss_scale='dynamic',
144
+ dtype='float16')
145
+
146
+ # learning policy
147
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
148
+ param_scheduler = [
149
+ dict(
150
+ type=LinearLR,
151
+ start_factor=1e-5,
152
+ by_epoch=True,
153
+ begin=0,
154
+ end=warmup_ratio * max_epochs,
155
+ convert_to_iter_based=True),
156
+ dict(
157
+ type=CosineAnnealingLR,
158
+ eta_min=0.0,
159
+ by_epoch=True,
160
+ begin=warmup_ratio * max_epochs,
161
+ end=max_epochs,
162
+ convert_to_iter_based=True)
163
+ ]
164
+
165
+ # train, val, test setting
166
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
167
+
168
+ #######################################################################
169
+ # PART 5 Runtime #
170
+ #######################################################################
171
+ # Log the dialogue periodically during the training process, optional
172
+ custom_hooks = [
173
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
174
+ dict(
175
+ type=EvaluateChatHook,
176
+ tokenizer=tokenizer,
177
+ image_processor=image_processor,
178
+ every_n_iters=evaluation_freq,
179
+ evaluation_inputs=evaluation_inputs,
180
+ evaluation_images=evaluation_images,
181
+ system=SYSTEM,
182
+ prompt_template=prompt_template)
183
+ ]
184
+
185
+ # configure default hooks
186
+ default_hooks = dict(
187
+ # record the time of every iteration.
188
+ timer=dict(type=IterTimerHook),
189
+ # print log every 10 iterations.
190
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
191
+ # enable the parameter scheduler.
192
+ param_scheduler=dict(type=ParamSchedulerHook),
193
+ # save checkpoint per `save_steps`.
194
+ checkpoint=dict(
195
+ type=CheckpointHook,
196
+ by_epoch=False,
197
+ interval=save_steps,
198
+ max_keep_ckpts=save_total_limit),
199
+ # set sampler seed in distributed evrionment.
200
+ sampler_seed=dict(type=DistSamplerSeedHook),
201
+ )
202
+
203
+ # configure environment
204
+ env_cfg = dict(
205
+ # whether to enable cudnn benchmark
206
+ cudnn_benchmark=False,
207
+ # set multi process parameters
208
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
209
+ # set distributed parameters
210
+ dist_cfg=dict(backend='nccl'),
211
+ )
212
+
213
+ # set visualizer
214
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
215
+ visualizer = dict(
216
+ type=Visualizer,
217
+ vis_backends=[dict(type=TensorboardVisBackend)]
218
+ )
219
+
220
+ # set log level
221
+ log_level = 'INFO'
222
+
223
+ # load from which checkpoint
224
+ load_from = None
225
+
226
+ # whether to resume training from the loaded checkpoint
227
+ resume = False
228
+
229
+ # Defaults to use random seed and disable `deterministic`
230
+ randomness = dict(seed=None, deterministic=False)
231
+
232
+ # set log processor
233
+ log_processor = dict(by_epoch=False)
pretrain.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, SiglipImageProcessor,
10
+ SiglipVisionModel)
11
+
12
+ from xtuner.dataset import LLaVADataset
13
+ from xtuner.dataset.collate_fns import default_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
16
+ from xtuner.engine.runner import TrainLoop
17
+ from xtuner.model import LLaVAModel
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+
20
+ #######################################################################
21
+ # PART 1 Settings #
22
+ #######################################################################
23
+ # Model
24
+ llm_name_or_path = 'internlm/internlm2-chat-1_8b'
25
+ visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
26
+
27
+ # Data
28
+ data_root = './'
29
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
30
+ image_folder = data_root + 'LLaVA-Pretrain/images'
31
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
32
+ max_length = int(2048 - (336 / 14)**2)
33
+
34
+ # Scheduler & Optimizer
35
+ batch_size = 16 # per_device
36
+ accumulative_counts = 4
37
+ dataloader_num_workers = 16
38
+ max_epochs = 1
39
+ optim_type = AdamW
40
+ lr = 1e-3
41
+ betas = (0.9, 0.999)
42
+ weight_decay = 0
43
+ max_norm = 1 # grad clip
44
+ warmup_ratio = 0.03
45
+
46
+ # Save
47
+ save_steps = 2000
48
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
49
+
50
+ # Evaluate the generation performance during the training
51
+ evaluation_freq = 2000
52
+ SYSTEM = ''
53
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
54
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
55
+
56
+ #######################################################################
57
+ # PART 2 Model & Tokenizer & Image Processor #
58
+ #######################################################################
59
+ tokenizer = dict(
60
+ type=AutoTokenizer.from_pretrained,
61
+ pretrained_model_name_or_path=llm_name_or_path,
62
+ trust_remote_code=True,
63
+ padding_side='right')
64
+
65
+ image_processor = dict(
66
+ type=SiglipImageProcessor.from_pretrained,
67
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
68
+ trust_remote_code=True)
69
+
70
+ model = dict(
71
+ type=LLaVAModel,
72
+ freeze_llm=True,
73
+ freeze_visual_encoder=True,
74
+ llm=dict(
75
+ type=AutoModelForCausalLM.from_pretrained,
76
+ pretrained_model_name_or_path=llm_name_or_path,
77
+ trust_remote_code=True,
78
+ torch_dtype=torch.float16,
79
+ quantization_config=dict(
80
+ type=BitsAndBytesConfig,
81
+ load_in_4bit=True,
82
+ load_in_8bit=False,
83
+ llm_int8_threshold=6.0,
84
+ llm_int8_has_fp16_weight=False,
85
+ bnb_4bit_compute_dtype=torch.float16,
86
+ bnb_4bit_use_double_quant=True,
87
+ bnb_4bit_quant_type='nf4')),
88
+ visual_encoder=dict(
89
+ type=SiglipVisionModel.from_pretrained,
90
+ pretrained_model_name_or_path=visual_encoder_name_or_path))
91
+
92
+ #######################################################################
93
+ # PART 3 Dataset & Dataloader #
94
+ #######################################################################
95
+ llava_dataset = dict(
96
+ type=LLaVADataset,
97
+ data_path=data_path,
98
+ image_folder=image_folder,
99
+ tokenizer=tokenizer,
100
+ image_processor=image_processor,
101
+ dataset_map_fn=llava_map_fn,
102
+ template_map_fn=dict(
103
+ type=template_map_fn_factory, template=prompt_template),
104
+ max_length=max_length,
105
+ pad_image_to_square=False)
106
+
107
+ train_dataloader = dict(
108
+ batch_size=batch_size,
109
+ num_workers=dataloader_num_workers,
110
+ dataset=llava_dataset,
111
+ sampler=dict(type=DefaultSampler, shuffle=True),
112
+ collate_fn=dict(type=default_collate_fn))
113
+
114
+ #######################################################################
115
+ # PART 4 Scheduler & Optimizer #
116
+ #######################################################################
117
+ # optimizer
118
+ optim_wrapper = dict(
119
+ type=AmpOptimWrapper,
120
+ optimizer=dict(
121
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
122
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
123
+ accumulative_counts=accumulative_counts,
124
+ loss_scale='dynamic',
125
+ dtype='float16')
126
+
127
+ # learning policy
128
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
129
+ param_scheduler = [
130
+ dict(
131
+ type=LinearLR,
132
+ start_factor=1e-5,
133
+ by_epoch=True,
134
+ begin=0,
135
+ end=warmup_ratio * max_epochs,
136
+ convert_to_iter_based=True),
137
+ dict(
138
+ type=CosineAnnealingLR,
139
+ eta_min=0.0,
140
+ by_epoch=True,
141
+ begin=warmup_ratio * max_epochs,
142
+ end=max_epochs,
143
+ convert_to_iter_based=True)
144
+ ]
145
+
146
+ # train, val, test setting
147
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
148
+
149
+ #######################################################################
150
+ # PART 5 Runtime #
151
+ #######################################################################
152
+ # Log the dialogue periodically during the training process, optional
153
+ custom_hooks = [
154
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
155
+ dict(
156
+ type=EvaluateChatHook,
157
+ tokenizer=tokenizer,
158
+ image_processor=image_processor,
159
+ every_n_iters=evaluation_freq,
160
+ evaluation_inputs=evaluation_inputs,
161
+ evaluation_images=evaluation_images,
162
+ system=SYSTEM,
163
+ prompt_template=prompt_template)
164
+ ]
165
+
166
+ # configure default hooks
167
+ default_hooks = dict(
168
+ # record the time of every iteration.
169
+ timer=dict(type=IterTimerHook),
170
+ # print log every 10 iterations.
171
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
172
+ # enable the parameter scheduler.
173
+ param_scheduler=dict(type=ParamSchedulerHook),
174
+ # save checkpoint per `save_steps`.
175
+ checkpoint=dict(
176
+ type=CheckpointHook,
177
+ by_epoch=False,
178
+ interval=save_steps,
179
+ max_keep_ckpts=save_total_limit),
180
+ # set sampler seed in distributed evrionment.
181
+ sampler_seed=dict(type=DistSamplerSeedHook),
182
+ )
183
+
184
+ # configure environment
185
+ env_cfg = dict(
186
+ # whether to enable cudnn benchmark
187
+ cudnn_benchmark=False,
188
+ # set multi process parameters
189
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
190
+ # set distributed parameters
191
+ dist_cfg=dict(backend='nccl'),
192
+ )
193
+
194
+ # set visualizer
195
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
196
+ visualizer = dict(
197
+ type=Visualizer,
198
+ vis_backends=[dict(type=TensorboardVisBackend)]
199
+ )
200
+
201
+ # set log level
202
+ log_level = 'INFO'
203
+
204
+ # load from which checkpoint
205
+ load_from = None
206
+
207
+ # whether to resume training from the loaded checkpoint
208
+ resume = False
209
+
210
+ # Defaults to use random seed and disable `deterministic`
211
+ randomness = dict(seed=None, deterministic=False)
212
+
213
+ # set log processor
214
+ log_processor = dict(by_epoch=False)