StarCycle commited on
Commit
377d3d1
·
1 Parent(s): d4d1f38
Files changed (31) hide show
  1. finetune.py +233 -0
  2. lora_and_projector/llm_adapter/README.md +202 -0
  3. lora_and_projector/llm_adapter/adapter_config.json +32 -0
  4. lora_and_projector/llm_adapter/adapter_model.safetensors +3 -0
  5. lora_and_projector/projector/config.json +17 -0
  6. lora_and_projector/projector/configuration_projector.py +23 -0
  7. lora_and_projector/projector/model.safetensors +3 -0
  8. lora_and_projector/projector/modeling_projector.py +51 -0
  9. lora_and_projector/visual_encoder_adapter/README.md +202 -0
  10. lora_and_projector/visual_encoder_adapter/adapter_config.json +35 -0
  11. lora_and_projector/visual_encoder_adapter/adapter_model.safetensors +3 -0
  12. lora_and_projector/xtuner_config.py +222 -0
  13. mmbench_results/20240310_025701/args.json +19 -0
  14. mmbench_results/20240310_025701/mmbench_result.json +9 -0
  15. mmbench_results/20240310_025701/mmbench_result.xlsx +0 -0
  16. mmbench_results/20240310_030410/args.json +19 -0
  17. mmbench_results/20240310_030410/mmbench_result.xlsx +0 -0
  18. mmbench_results/20240310_031346/args.json +19 -0
  19. mmbench_results/20240310_031346/mmbench_result.json +9 -0
  20. mmbench_results/20240310_031346/mmbench_result.xlsx +0 -0
  21. mmbench_results/20240310_032208/args.json +19 -0
  22. mmbench_results/20240310_032208/mmbench_result.xlsx +0 -0
  23. mmbench_results/20240310_033150/args.json +19 -0
  24. mmbench_results/20240310_033150/mmbench_result.json +10 -0
  25. mmbench_results/20240310_033150/mmbench_result.xlsx +0 -0
  26. modified_transformers/src/transformers/models/siglip/modeling_siglip.py +1299 -0
  27. modified_xtuner/xtuner/dataset/huggingface.py +316 -0
  28. modified_xtuner/xtuner/dataset/llava.py +88 -0
  29. modified_xtuner/xtuner/tools/chat.py +491 -0
  30. modified_xtuner/xtuner/tools/mmbench.py +510 -0
  31. pretrain.py +214 -0
finetune.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
4
+ LoggerHook, ParamSchedulerHook)
5
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
6
+ from peft import LoraConfig
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, SiglipImageProcessor,
10
+ SiglipVisionModel)
11
+
12
+ from xtuner.dataset import LLaVADataset
13
+ from xtuner.dataset.collate_fns import default_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from xtuner.dataset.samplers import LengthGroupedSampler
16
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
17
+ from xtuner.engine.runner import TrainLoop
18
+ from xtuner.model import LLaVAModel
19
+ from xtuner.utils import PROMPT_TEMPLATE
20
+
21
+ #######################################################################
22
+ # PART 1 Settings #
23
+ #######################################################################
24
+ # Model
25
+ llm_name_or_path = 'internlm/internlm2-1_8b'
26
+ visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
27
+ # Specify the pretrained pth
28
+ pretrained_pth = './work_dirs/pretrain/iter_11628.pth'
29
+
30
+ # Data
31
+ data_root = './'
32
+ data_path = data_root + 'LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
33
+ image_folder = data_root + 'llava_images'
34
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
35
+ max_length = int(2048 - (336 / 14)**2)
36
+
37
+ # Scheduler & Optimizer
38
+ batch_size = 4 # per_device
39
+ accumulative_counts = 8
40
+ dataloader_num_workers = 4
41
+ prefetch = 5
42
+ max_epochs = 1
43
+ optim_type = AdamW
44
+ lr = 2e-4
45
+ betas = (0.9, 0.999)
46
+ weight_decay = 0
47
+ max_norm = 1 # grad clip
48
+ warmup_ratio = 0.03
49
+
50
+ # Save
51
+ save_steps = 500
52
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
53
+
54
+ # Evaluate the generation performance during the training
55
+ evaluation_freq = 500
56
+ SYSTEM = ''
57
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
58
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
59
+
60
+ #######################################################################
61
+ # PART 2 Model & Tokenizer & Image Processor #
62
+ #######################################################################
63
+ tokenizer = dict(
64
+ type=AutoTokenizer.from_pretrained,
65
+ pretrained_model_name_or_path=llm_name_or_path,
66
+ trust_remote_code=True,
67
+ padding_side='right')
68
+
69
+ image_processor = dict(
70
+ type=SiglipImageProcessor.from_pretrained,
71
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
72
+ trust_remote_code=True)
73
+
74
+ model = dict(
75
+ type=LLaVAModel,
76
+ freeze_llm=True,
77
+ freeze_visual_encoder=True,
78
+ pretrained_pth=pretrained_pth,
79
+ llm=dict(
80
+ type=AutoModelForCausalLM.from_pretrained,
81
+ pretrained_model_name_or_path=llm_name_or_path,
82
+ trust_remote_code=True,
83
+ torch_dtype=torch.float16,
84
+ quantization_config=dict(
85
+ type=BitsAndBytesConfig,
86
+ load_in_4bit=True,
87
+ load_in_8bit=False,
88
+ llm_int8_threshold=6.0,
89
+ llm_int8_has_fp16_weight=False,
90
+ bnb_4bit_compute_dtype=torch.float16,
91
+ bnb_4bit_use_double_quant=True,
92
+ bnb_4bit_quant_type='nf4')),
93
+ llm_lora=dict(
94
+ type=LoraConfig,
95
+ r=512,
96
+ lora_alpha=256,
97
+ lora_dropout=0.05,
98
+ bias='none',
99
+ task_type='CAUSAL_LM'),
100
+ visual_encoder=dict(
101
+ type=SiglipVisionModel.from_pretrained,
102
+ pretrained_model_name_or_path=visual_encoder_name_or_path),
103
+ visual_encoder_lora=dict(
104
+ type=LoraConfig, r=64, lora_alpha=16, lora_dropout=0.05, bias='none'),
105
+ )
106
+
107
+ #######################################################################
108
+ # PART 3 Dataset & Dataloader #
109
+ #######################################################################
110
+ llava_dataset = dict(
111
+ type=LLaVADataset,
112
+ data_path=data_path,
113
+ image_folder=image_folder,
114
+ tokenizer=tokenizer,
115
+ image_processor=image_processor,
116
+ dataset_map_fn=llava_map_fn,
117
+ template_map_fn=dict(
118
+ type=template_map_fn_factory, template=prompt_template),
119
+ max_length=max_length,
120
+ pad_image_to_square=True)
121
+
122
+ train_dataloader = dict(
123
+ batch_size=batch_size,
124
+ num_workers=dataloader_num_workers,
125
+ prefetch_factor=prefetch,
126
+ dataset=llava_dataset,
127
+ sampler=dict(
128
+ type=LengthGroupedSampler,
129
+ length_property='modality_length',
130
+ per_device_batch_size=batch_size * accumulative_counts),
131
+ collate_fn=dict(type=default_collate_fn))
132
+
133
+ #######################################################################
134
+ # PART 4 Scheduler & Optimizer #
135
+ #######################################################################
136
+ # optimizer
137
+ optim_wrapper = dict(
138
+ type=AmpOptimWrapper,
139
+ optimizer=dict(
140
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
141
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
142
+ accumulative_counts=accumulative_counts,
143
+ loss_scale='dynamic',
144
+ dtype='float16')
145
+
146
+ # learning policy
147
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
148
+ param_scheduler = [
149
+ dict(
150
+ type=LinearLR,
151
+ start_factor=1e-5,
152
+ by_epoch=True,
153
+ begin=0,
154
+ end=warmup_ratio * max_epochs,
155
+ convert_to_iter_based=True),
156
+ dict(
157
+ type=CosineAnnealingLR,
158
+ eta_min=0.0,
159
+ by_epoch=True,
160
+ begin=warmup_ratio * max_epochs,
161
+ end=max_epochs,
162
+ convert_to_iter_based=True)
163
+ ]
164
+
165
+ # train, val, test setting
166
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
167
+
168
+ #######################################################################
169
+ # PART 5 Runtime #
170
+ #######################################################################
171
+ # Log the dialogue periodically during the training process, optional
172
+ custom_hooks = [
173
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
174
+ dict(
175
+ type=EvaluateChatHook,
176
+ tokenizer=tokenizer,
177
+ image_processor=image_processor,
178
+ every_n_iters=evaluation_freq,
179
+ evaluation_inputs=evaluation_inputs,
180
+ evaluation_images=evaluation_images,
181
+ system=SYSTEM,
182
+ prompt_template=prompt_template)
183
+ ]
184
+
185
+ # configure default hooks
186
+ default_hooks = dict(
187
+ # record the time of every iteration.
188
+ timer=dict(type=IterTimerHook),
189
+ # print log every 10 iterations.
190
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
191
+ # enable the parameter scheduler.
192
+ param_scheduler=dict(type=ParamSchedulerHook),
193
+ # save checkpoint per `save_steps`.
194
+ checkpoint=dict(
195
+ type=CheckpointHook,
196
+ by_epoch=False,
197
+ interval=save_steps,
198
+ max_keep_ckpts=save_total_limit),
199
+ # set sampler seed in distributed evrionment.
200
+ sampler_seed=dict(type=DistSamplerSeedHook),
201
+ )
202
+
203
+ # configure environment
204
+ env_cfg = dict(
205
+ # whether to enable cudnn benchmark
206
+ cudnn_benchmark=False,
207
+ # set multi process parameters
208
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
209
+ # set distributed parameters
210
+ dist_cfg=dict(backend='nccl'),
211
+ )
212
+
213
+ # set visualizer
214
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
215
+ visualizer = dict(
216
+ type=Visualizer,
217
+ vis_backends=[dict(type=TensorboardVisBackend)]
218
+ )
219
+
220
+ # set log level
221
+ log_level = 'INFO'
222
+
223
+ # load from which checkpoint
224
+ load_from = None
225
+
226
+ # whether to resume training from the loaded checkpoint
227
+ resume = False
228
+
229
+ # Defaults to use random seed and disable `deterministic`
230
+ randomness = dict(seed=None, deterministic=False)
231
+
232
+ # set log processor
233
+ log_processor = dict(by_epoch=False)
lora_and_projector/llm_adapter/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: peft
3
+ base_model: internlm/internlm2-1_8b
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.9.1.dev0
lora_and_projector/llm_adapter/adapter_config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "internlm/internlm2-1_8b",
5
+ "bias": "none",
6
+ "fan_in_fan_out": false,
7
+ "inference_mode": true,
8
+ "init_lora_weights": true,
9
+ "layers_pattern": null,
10
+ "layers_to_transform": null,
11
+ "loftq_config": {},
12
+ "lora_alpha": 256,
13
+ "lora_dropout": 0.05,
14
+ "megatron_config": null,
15
+ "megatron_core": "megatron.core",
16
+ "modules_to_save": null,
17
+ "peft_type": "LORA",
18
+ "r": 512,
19
+ "rank_pattern": {},
20
+ "revision": null,
21
+ "target_modules": [
22
+ "w2",
23
+ "wo",
24
+ "wqkv",
25
+ "output",
26
+ "w1",
27
+ "w3"
28
+ ],
29
+ "task_type": "CAUSAL_LM",
30
+ "use_dora": false,
31
+ "use_rslora": false
32
+ }
lora_and_projector/llm_adapter/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df04221be2e75431ea2069d3815fb07f2d551770ceb8e4278aaebcfc537a896c
3
+ size 1103527968
lora_and_projector/projector/config.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ProjectorModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_projector.ProjectorConfig",
7
+ "AutoModel": "modeling_projector.ProjectorModel"
8
+ },
9
+ "bias": true,
10
+ "depth": 2,
11
+ "hidden_act": "gelu",
12
+ "llm_hidden_size": 2048,
13
+ "model_type": "projector",
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.39.0.dev0",
16
+ "visual_hidden_size": 1152
17
+ }
lora_and_projector/projector/configuration_projector.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class ProjectorConfig(PretrainedConfig):
6
+ model_type = 'projector'
7
+ _auto_class = 'AutoConfig'
8
+
9
+ def __init__(
10
+ self,
11
+ visual_hidden_size=4096,
12
+ llm_hidden_size=4096,
13
+ depth=2,
14
+ hidden_act='gelu',
15
+ bias=True,
16
+ **kwargs,
17
+ ):
18
+ self.visual_hidden_size = visual_hidden_size
19
+ self.llm_hidden_size = llm_hidden_size
20
+ self.depth = depth
21
+ self.hidden_act = hidden_act
22
+ self.bias = bias
23
+ super().__init__(**kwargs)
lora_and_projector/projector/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:106abdd4e4391aede50872561811949888780705ce46bcd30b7dd45fa022acdd
3
+ size 26231144
lora_and_projector/projector/modeling_projector.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import PreTrainedModel
5
+ from transformers.activations import ACT2FN
6
+
7
+ from .configuration_projector import ProjectorConfig
8
+
9
+
10
+ class ProjectorModel(PreTrainedModel):
11
+ _auto_class = 'AutoModel'
12
+ config_class = ProjectorConfig
13
+ base_model_prefix = 'model'
14
+ supports_gradient_checkpointing = True
15
+
16
+ def __init__(self, config: ProjectorConfig) -> None:
17
+ super().__init__(config)
18
+ self.gradient_checkpointing = False
19
+
20
+ modules = [
21
+ nn.Linear(
22
+ config.visual_hidden_size,
23
+ config.llm_hidden_size,
24
+ bias=config.bias)
25
+ ]
26
+ for _ in range(1, config.depth):
27
+ modules.append(ACT2FN[config.hidden_act])
28
+ modules.append(
29
+ nn.Linear(
30
+ config.llm_hidden_size,
31
+ config.llm_hidden_size,
32
+ bias=config.bias))
33
+ self.model = nn.Sequential(*modules)
34
+
35
+ def enable_input_require_grads(self):
36
+
37
+ def make_inputs_require_grad(module, input, output):
38
+ output.requires_grad_(True)
39
+
40
+ self.model.register_forward_hook(make_inputs_require_grad)
41
+
42
+ def _set_gradient_checkpointing(self, module, value=False):
43
+ if isinstance(module, ProjectorModel):
44
+ module.gradient_checkpointing = value
45
+
46
+ def forward(self, x):
47
+ if self.gradient_checkpointing and self.training:
48
+ layer_outputs = torch.utils.checkpoint.checkpoint(self.model, x)
49
+ else:
50
+ layer_outputs = self.model(x)
51
+ return layer_outputs
lora_and_projector/visual_encoder_adapter/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: peft
3
+ base_model: google/siglip-so400m-patch14-384
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.9.1.dev0
lora_and_projector/visual_encoder_adapter/adapter_config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": {
4
+ "base_model_class": "SiglipVisionModel",
5
+ "parent_library": "transformers.models.siglip.modeling_siglip"
6
+ },
7
+ "base_model_name_or_path": "google/siglip-so400m-patch14-384",
8
+ "bias": "none",
9
+ "fan_in_fan_out": false,
10
+ "inference_mode": true,
11
+ "init_lora_weights": true,
12
+ "layers_pattern": null,
13
+ "layers_to_transform": null,
14
+ "loftq_config": {},
15
+ "lora_alpha": 16,
16
+ "lora_dropout": 0.05,
17
+ "megatron_config": null,
18
+ "megatron_core": "megatron.core",
19
+ "modules_to_save": null,
20
+ "peft_type": "LORA",
21
+ "r": 64,
22
+ "rank_pattern": {},
23
+ "revision": null,
24
+ "target_modules": [
25
+ "k_proj",
26
+ "fc2",
27
+ "v_proj",
28
+ "fc1",
29
+ "q_proj",
30
+ "out_proj"
31
+ ],
32
+ "task_type": null,
33
+ "use_dora": false,
34
+ "use_rslora": false
35
+ }
lora_and_projector/visual_encoder_adapter/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d62d7122066f9d37d97c2c3288a125ebc088b2731678593d9716d55fad7e230
3
+ size 142556624
lora_and_projector/xtuner_config.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SYSTEM = ''
2
+ accumulative_counts = 8
3
+ batch_size = 4
4
+ betas = (
5
+ 0.9,
6
+ 0.999,
7
+ )
8
+ custom_hooks = [
9
+ dict(
10
+ tokenizer=dict(
11
+ padding_side='right',
12
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
13
+ trust_remote_code=True,
14
+ type='transformers.AutoTokenizer.from_pretrained'),
15
+ type='xtuner.engine.hooks.DatasetInfoHook'),
16
+ dict(
17
+ evaluation_images='https://llava-vl.github.io/static/images/view.jpg',
18
+ evaluation_inputs=[
19
+ '请描述一下这张照片',
20
+ 'Please describe this picture',
21
+ ],
22
+ every_n_iters=500,
23
+ image_processor=dict(
24
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
25
+ trust_remote_code=True,
26
+ type='transformers.SiglipImageProcessor.from_pretrained'),
27
+ prompt_template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
28
+ system='',
29
+ tokenizer=dict(
30
+ padding_side='right',
31
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
32
+ trust_remote_code=True,
33
+ type='transformers.AutoTokenizer.from_pretrained'),
34
+ type='xtuner.engine.hooks.EvaluateChatHook'),
35
+ ]
36
+ data_path = './LLaVA-Instruct-150K/llava_v1_5_mix665k.json'
37
+ data_root = './'
38
+ dataloader_num_workers = 4
39
+ default_hooks = dict(
40
+ checkpoint=dict(
41
+ by_epoch=False,
42
+ interval=500,
43
+ max_keep_ckpts=2,
44
+ type='mmengine.hooks.CheckpointHook'),
45
+ logger=dict(
46
+ interval=10,
47
+ log_metric_by_epoch=False,
48
+ type='mmengine.hooks.LoggerHook'),
49
+ param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
50
+ sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
51
+ timer=dict(type='mmengine.hooks.IterTimerHook'))
52
+ env_cfg = dict(
53
+ cudnn_benchmark=False,
54
+ dist_cfg=dict(backend='nccl'),
55
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
56
+ evaluation_freq = 500
57
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
58
+ evaluation_inputs = [
59
+ '请描述一下这张照片',
60
+ 'Please describe this picture',
61
+ ]
62
+ image_folder = './llava_images'
63
+ image_processor = dict(
64
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
65
+ trust_remote_code=True,
66
+ type='transformers.SiglipImageProcessor.from_pretrained')
67
+ launcher = 'pytorch'
68
+ llava_dataset = dict(
69
+ data_path='./LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
70
+ dataset_map_fn='xtuner.dataset.map_fns.llava_map_fn',
71
+ image_folder='./llava_images',
72
+ image_processor=dict(
73
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
74
+ trust_remote_code=True,
75
+ type='transformers.SiglipImageProcessor.from_pretrained'),
76
+ max_length=1472,
77
+ pad_image_to_square=True,
78
+ template_map_fn=dict(
79
+ template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
80
+ type='xtuner.dataset.map_fns.template_map_fn_factory'),
81
+ tokenizer=dict(
82
+ padding_side='right',
83
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
84
+ trust_remote_code=True,
85
+ type='transformers.AutoTokenizer.from_pretrained'),
86
+ type='xtuner.dataset.LLaVADataset')
87
+ llm_name_or_path = 'internlm/internlm2-1_8b'
88
+ load_from = None
89
+ log_level = 'INFO'
90
+ log_processor = dict(by_epoch=False)
91
+ lr = 0.0002
92
+ max_epochs = 1
93
+ max_length = 1472
94
+ max_norm = 1
95
+ model = dict(
96
+ freeze_llm=True,
97
+ freeze_visual_encoder=True,
98
+ llm=dict(
99
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
100
+ quantization_config=dict(
101
+ bnb_4bit_compute_dtype='torch.float16',
102
+ bnb_4bit_quant_type='nf4',
103
+ bnb_4bit_use_double_quant=True,
104
+ llm_int8_has_fp16_weight=False,
105
+ llm_int8_threshold=6.0,
106
+ load_in_4bit=True,
107
+ load_in_8bit=False,
108
+ type='transformers.BitsAndBytesConfig'),
109
+ torch_dtype='torch.float16',
110
+ trust_remote_code=True,
111
+ type='transformers.AutoModelForCausalLM.from_pretrained'),
112
+ llm_lora=dict(
113
+ bias='none',
114
+ lora_alpha=256,
115
+ lora_dropout=0.05,
116
+ r=512,
117
+ task_type='CAUSAL_LM',
118
+ type='peft.LoraConfig'),
119
+ pretrained_pth='./work_dirs/pretrain/iter_11628.pth',
120
+ type='xtuner.model.LLaVAModel',
121
+ visual_encoder=dict(
122
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
123
+ type='transformers.SiglipVisionModel.from_pretrained'),
124
+ visual_encoder_lora=dict(
125
+ bias='none',
126
+ lora_alpha=16,
127
+ lora_dropout=0.05,
128
+ r=64,
129
+ type='peft.LoraConfig'))
130
+ optim_type = 'torch.optim.AdamW'
131
+ optim_wrapper = dict(
132
+ optimizer=dict(
133
+ betas=(
134
+ 0.9,
135
+ 0.999,
136
+ ),
137
+ lr=0.0002,
138
+ type='torch.optim.AdamW',
139
+ weight_decay=0),
140
+ type='DeepSpeedOptimWrapper')
141
+ param_scheduler = [
142
+ dict(
143
+ begin=0,
144
+ by_epoch=True,
145
+ convert_to_iter_based=True,
146
+ end=0.03,
147
+ start_factor=1e-05,
148
+ type='mmengine.optim.LinearLR'),
149
+ dict(
150
+ begin=0.03,
151
+ by_epoch=True,
152
+ convert_to_iter_based=True,
153
+ end=1,
154
+ eta_min=0.0,
155
+ type='mmengine.optim.CosineAnnealingLR'),
156
+ ]
157
+ prefetch = 5
158
+ pretrained_pth = './work_dirs/pretrain/iter_11628.pth'
159
+ prompt_template = 'xtuner.utils.PROMPT_TEMPLATE.internlm2_chat'
160
+ randomness = dict(deterministic=False, seed=None)
161
+ resume = False
162
+ runner_type = 'FlexibleRunner'
163
+ save_steps = 500
164
+ save_total_limit = 2
165
+ strategy = dict(
166
+ config=dict(
167
+ bf16=dict(enabled=True),
168
+ fp16=dict(enabled=False, initial_scale_power=16),
169
+ gradient_accumulation_steps='auto',
170
+ gradient_clipping='auto',
171
+ train_micro_batch_size_per_gpu='auto',
172
+ zero_allow_untested_optimizer=True,
173
+ zero_force_ds_cpu_optimizer=False,
174
+ zero_optimization=dict(overlap_comm=True, stage=2)),
175
+ exclude_frozen_parameters=True,
176
+ gradient_accumulation_steps=8,
177
+ gradient_clipping=1,
178
+ train_micro_batch_size_per_gpu=4,
179
+ type='xtuner.engine.DeepSpeedStrategy')
180
+ tokenizer = dict(
181
+ padding_side='right',
182
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
183
+ trust_remote_code=True,
184
+ type='transformers.AutoTokenizer.from_pretrained')
185
+ train_cfg = dict(max_epochs=1, type='xtuner.engine.runner.TrainLoop')
186
+ train_dataloader = dict(
187
+ batch_size=4,
188
+ collate_fn=dict(type='xtuner.dataset.collate_fns.default_collate_fn'),
189
+ dataset=dict(
190
+ data_path='./LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
191
+ dataset_map_fn='xtuner.dataset.map_fns.llava_map_fn',
192
+ image_folder='./llava_images',
193
+ image_processor=dict(
194
+ pretrained_model_name_or_path='google/siglip-so400m-patch14-384',
195
+ trust_remote_code=True,
196
+ type='transformers.SiglipImageProcessor.from_pretrained'),
197
+ max_length=1472,
198
+ pad_image_to_square=True,
199
+ template_map_fn=dict(
200
+ template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
201
+ type='xtuner.dataset.map_fns.template_map_fn_factory'),
202
+ tokenizer=dict(
203
+ padding_side='right',
204
+ pretrained_model_name_or_path='internlm/internlm2-1_8b',
205
+ trust_remote_code=True,
206
+ type='transformers.AutoTokenizer.from_pretrained'),
207
+ type='xtuner.dataset.LLaVADataset'),
208
+ num_workers=4,
209
+ prefetch_factor=5,
210
+ sampler=dict(
211
+ length_property='modality_length',
212
+ per_device_batch_size=32,
213
+ type='xtuner.dataset.samplers.LengthGroupedSampler'))
214
+ visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
215
+ visualizer = dict(
216
+ type='mmengine.visualization.Visualizer',
217
+ vis_backends=[
218
+ dict(type='mmengine.visualization.TensorboardVisBackend'),
219
+ ])
220
+ warmup_ratio = 0.03
221
+ weight_decay = 0
222
+ work_dir = './work_dirs/finetune'
mmbench_results/20240310_025701/args.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name_or_path": "internlm/internlm2-1_8b",
3
+ "data_path": "MMBench_DEV_EN.tsv",
4
+ "work_dir": "work_dirs/finetune/results",
5
+ "llava": "work_dirs/finetune/hf/",
6
+ "visual_encoder": "google/siglip-so400m-patch14-384",
7
+ "visual_select_layer": -2,
8
+ "prompt_template": "internlm2_chat",
9
+ "stop_words": [
10
+ "<|im_end|>"
11
+ ],
12
+ "torch_dtype": "fp16",
13
+ "bits": null,
14
+ "bot_name": "BOT",
15
+ "offload_folder": null,
16
+ "max_new_tokens": 100,
17
+ "seed": 0,
18
+ "launcher": "pytorch"
19
+ }
mmbench_results/20240310_025701/mmbench_result.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Average": 0.6709621993127147,
3
+ "AR": 0.6984924623115578,
4
+ "CP": 0.8175675675675675,
5
+ "FP-C": 0.5944055944055944,
6
+ "FP-S": 0.6860068259385665,
7
+ "LR": 0.3305084745762712,
8
+ "RR": 0.6521739130434783
9
+ }
mmbench_results/20240310_025701/mmbench_result.xlsx ADDED
Binary file (368 kB). View file
 
mmbench_results/20240310_030410/args.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name_or_path": "internlm/internlm2-1_8b",
3
+ "data_path": "MMBench_TEST_EN.tsv",
4
+ "work_dir": "work_dirs/finetune/results",
5
+ "llava": "work_dirs/finetune/hf/",
6
+ "visual_encoder": "google/siglip-so400m-patch14-384",
7
+ "visual_select_layer": -2,
8
+ "prompt_template": "internlm2_chat",
9
+ "stop_words": [
10
+ "<|im_end|>"
11
+ ],
12
+ "torch_dtype": "fp16",
13
+ "bits": null,
14
+ "bot_name": "BOT",
15
+ "offload_folder": null,
16
+ "max_new_tokens": 100,
17
+ "seed": 0,
18
+ "launcher": "pytorch"
19
+ }
mmbench_results/20240310_030410/mmbench_result.xlsx ADDED
Binary file (546 kB). View file
 
mmbench_results/20240310_031346/args.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name_or_path": "internlm/internlm2-1_8b",
3
+ "data_path": "MMBench_DEV_CN.tsv",
4
+ "work_dir": "work_dirs/finetune/results",
5
+ "llava": "work_dirs/finetune/hf/",
6
+ "visual_encoder": "google/siglip-so400m-patch14-384",
7
+ "visual_select_layer": -2,
8
+ "prompt_template": "internlm2_chat",
9
+ "stop_words": [
10
+ "<|im_end|>"
11
+ ],
12
+ "torch_dtype": "fp16",
13
+ "bits": null,
14
+ "bot_name": "BOT",
15
+ "offload_folder": null,
16
+ "max_new_tokens": 100,
17
+ "seed": 0,
18
+ "launcher": "pytorch"
19
+ }
mmbench_results/20240310_031346/mmbench_result.json ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Average": 0.6408934707903781,
3
+ "AR": 0.6582914572864321,
4
+ "CP": 0.7972972972972973,
5
+ "FP-C": 0.6013986013986014,
6
+ "FP-S": 0.6416382252559727,
7
+ "LR": 0.288135593220339,
8
+ "RR": 0.6173913043478261
9
+ }
mmbench_results/20240310_031346/mmbench_result.xlsx ADDED
Binary file (432 kB). View file
 
mmbench_results/20240310_032208/args.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name_or_path": "internlm/internlm2-1_8b",
3
+ "data_path": "MMBench_TEST_CN.tsv",
4
+ "work_dir": "work_dirs/finetune/results",
5
+ "llava": "work_dirs/finetune/hf/",
6
+ "visual_encoder": "google/siglip-so400m-patch14-384",
7
+ "visual_select_layer": -2,
8
+ "prompt_template": "internlm2_chat",
9
+ "stop_words": [
10
+ "<|im_end|>"
11
+ ],
12
+ "torch_dtype": "fp16",
13
+ "bits": null,
14
+ "bot_name": "BOT",
15
+ "offload_folder": null,
16
+ "max_new_tokens": 100,
17
+ "seed": 0,
18
+ "launcher": "pytorch"
19
+ }
mmbench_results/20240310_032208/mmbench_result.xlsx ADDED
Binary file (610 kB). View file
 
mmbench_results/20240310_033150/args.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_name_or_path": "internlm/internlm2-1_8b",
3
+ "data_path": "CCBench.tsv",
4
+ "work_dir": "work_dirs/finetune/results",
5
+ "llava": "work_dirs/finetune/hf/",
6
+ "visual_encoder": "google/siglip-so400m-patch14-384",
7
+ "visual_select_layer": -2,
8
+ "prompt_template": "internlm2_chat",
9
+ "stop_words": [
10
+ "<|im_end|>"
11
+ ],
12
+ "torch_dtype": "fp16",
13
+ "bits": null,
14
+ "bot_name": "BOT",
15
+ "offload_folder": null,
16
+ "max_new_tokens": 100,
17
+ "seed": 0,
18
+ "launcher": "pytorch"
19
+ }
mmbench_results/20240310_033150/mmbench_result.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "Average": 0.37254901960784315,
3
+ "Calligraphy Painting": 0.40350877192982454,
4
+ "Cultural Relic": 0.31958762886597936,
5
+ "Food & Clothes": 0.4608695652173913,
6
+ "Historical Figure": 0.05714285714285714,
7
+ "Scenery & Building": 0.3368421052631579,
8
+ "Sketch Reasoning": 0.6222222222222222,
9
+ "Traditional Show": 0.3181818181818182
10
+ }
mmbench_results/20240310_033150/mmbench_result.xlsx ADDED
Binary file (114 kB). View file
 
modified_transformers/src/transformers/models/siglip/modeling_siglip.py ADDED
@@ -0,0 +1,1299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 Google AI and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Siglip model."""
16
+
17
+
18
+ import math
19
+ import warnings
20
+ from dataclasses import dataclass
21
+ from typing import Any, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+ from torch.nn.init import _calculate_fan_in_and_fan_out
29
+
30
+ from ...activations import ACT2FN
31
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
32
+ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
33
+ from ...modeling_utils import PreTrainedModel
34
+ from ...utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ replace_return_docstrings,
41
+ )
42
+ from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
43
+
44
+
45
+ logger = logging.get_logger(__name__)
46
+
47
+ # General docstring
48
+ _CONFIG_FOR_DOC = "SiglipConfig"
49
+ _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224"
50
+
51
+ # Image classification docstring
52
+ _IMAGE_CLASS_CHECKPOINT = "google/siglip-base-patch16-224"
53
+ _IMAGE_CLASS_EXPECTED_OUTPUT = "LABEL_1"
54
+
55
+
56
+ SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
57
+ "google/siglip-base-patch16-224",
58
+ # See all SigLIP models at https://huggingface.co/models?filter=siglip
59
+ ]
60
+
61
+
62
+ def _trunc_normal_(tensor, mean, std, a, b):
63
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
64
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
65
+ def norm_cdf(x):
66
+ # Computes standard normal cumulative distribution function
67
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
68
+
69
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
70
+ warnings.warn(
71
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
72
+ "The distribution of values may be incorrect.",
73
+ stacklevel=2,
74
+ )
75
+
76
+ # Values are generated by using a truncated uniform distribution and
77
+ # then using the inverse CDF for the normal distribution.
78
+ # Get upper and lower cdf values
79
+ l = norm_cdf((a - mean) / std)
80
+ u = norm_cdf((b - mean) / std)
81
+
82
+ # Uniformly fill tensor with values from [l, u], then translate to
83
+ # [2l-1, 2u-1].
84
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
85
+
86
+ # Use inverse cdf transform for normal distribution to get truncated
87
+ # standard normal
88
+ tensor.erfinv_()
89
+
90
+ # Transform to proper mean, std
91
+ tensor.mul_(std * math.sqrt(2.0))
92
+ tensor.add_(mean)
93
+
94
+ # Clamp to ensure it's in the proper range
95
+ tensor.clamp_(min=a, max=b)
96
+
97
+
98
+ def trunc_normal_tf_(
99
+ tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
100
+ ) -> torch.Tensor:
101
+ """Fills the input Tensor with values drawn from a truncated
102
+ normal distribution. The values are effectively drawn from the
103
+ normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)`
104
+ with values outside :math:`[a, b]` redrawn until they are within
105
+ the bounds. The method used for generating the random values works
106
+ best when :math:`a \\leq \text{mean} \\leq b`.
107
+
108
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
109
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
110
+ and the result is subsquently scaled and shifted by the mean and std args.
111
+
112
+ Args:
113
+ tensor: an n-dimensional `torch.Tensor`
114
+ mean: the mean of the normal distribution
115
+ std: the standard deviation of the normal distribution
116
+ a: the minimum cutoff value
117
+ b: the maximum cutoff value
118
+ """
119
+ with torch.no_grad():
120
+ _trunc_normal_(tensor, 0, 1.0, a, b)
121
+ tensor.mul_(std).add_(mean)
122
+
123
+
124
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
125
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
126
+ if mode == "fan_in":
127
+ denom = fan_in
128
+ elif mode == "fan_out":
129
+ denom = fan_out
130
+ elif mode == "fan_avg":
131
+ denom = (fan_in + fan_out) / 2
132
+
133
+ variance = scale / denom
134
+
135
+ if distribution == "truncated_normal":
136
+ # constant is stddev of standard normal truncated to (-2, 2)
137
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
138
+ elif distribution == "normal":
139
+ with torch.no_grad():
140
+ tensor.normal_(std=math.sqrt(variance))
141
+ elif distribution == "uniform":
142
+ bound = math.sqrt(3 * variance)
143
+ with torch.no_grad():
144
+ tensor.uniform_(-bound, bound)
145
+ else:
146
+ raise ValueError(f"invalid distribution {distribution}")
147
+
148
+
149
+ def lecun_normal_(tensor):
150
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
151
+
152
+
153
+ def default_flax_embed_init(tensor):
154
+ variance_scaling_(tensor, mode="fan_in", distribution="normal")
155
+
156
+
157
+ @dataclass
158
+ # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip
159
+ class SiglipVisionModelOutput(ModelOutput):
160
+ """
161
+ Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
162
+
163
+ Args:
164
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
165
+ The image embeddings obtained by applying the projection layer to the pooler_output.
166
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
167
+ Sequence of hidden-states at the output of the last layer of the model.
168
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
169
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
170
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
171
+
172
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
173
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
174
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
175
+ sequence_length)`.
176
+
177
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
178
+ heads.
179
+ """
180
+
181
+ image_embeds: Optional[torch.FloatTensor] = None
182
+ last_hidden_state: torch.FloatTensor = None
183
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
184
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
185
+
186
+
187
+ @dataclass
188
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip
189
+ class SiglipTextModelOutput(ModelOutput):
190
+ """
191
+ Base class for text model's outputs that also contains a pooling of the last hidden states.
192
+
193
+ Args:
194
+ text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
195
+ The text embeddings obtained by applying the projection layer to the pooler_output.
196
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
197
+ Sequence of hidden-states at the output of the last layer of the model.
198
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
199
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
200
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
201
+
202
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
203
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
204
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
205
+ sequence_length)`.
206
+
207
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
208
+ heads.
209
+ """
210
+
211
+ text_embeds: Optional[torch.FloatTensor] = None
212
+ last_hidden_state: torch.FloatTensor = None
213
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
214
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
215
+
216
+
217
+ @dataclass
218
+ # Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip
219
+ class SiglipOutput(ModelOutput):
220
+ """
221
+ Args:
222
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
223
+ Contrastive loss for image-text similarity.
224
+ logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
225
+ The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
226
+ similarity scores.
227
+ logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
228
+ The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
229
+ similarity scores.
230
+ text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
231
+ The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`].
232
+ image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`):
233
+ The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`].
234
+ text_model_output(`BaseModelOutputWithPooling`):
235
+ The output of the [`SiglipTextModel`].
236
+ vision_model_output(`BaseModelOutputWithPooling`):
237
+ The output of the [`SiglipVisionModel`].
238
+ """
239
+
240
+ loss: Optional[torch.FloatTensor] = None
241
+ logits_per_image: torch.FloatTensor = None
242
+ logits_per_text: torch.FloatTensor = None
243
+ text_embeds: torch.FloatTensor = None
244
+ image_embeds: torch.FloatTensor = None
245
+ text_model_output: BaseModelOutputWithPooling = None
246
+ vision_model_output: BaseModelOutputWithPooling = None
247
+
248
+ def to_tuple(self) -> Tuple[Any]:
249
+ return tuple(
250
+ self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
251
+ for k in self.keys()
252
+ )
253
+
254
+
255
+ class SiglipVisionEmbeddings(nn.Module):
256
+ def __init__(self, config: SiglipVisionConfig):
257
+ super().__init__()
258
+ self.config = config
259
+ self.embed_dim = config.hidden_size
260
+ self.image_size = config.image_size
261
+ self.patch_size = config.patch_size
262
+
263
+ self.patch_embedding = nn.Conv2d(
264
+ in_channels=config.num_channels,
265
+ out_channels=self.embed_dim,
266
+ kernel_size=self.patch_size,
267
+ stride=self.patch_size,
268
+ padding="valid",
269
+ )
270
+
271
+ self.num_patches = (self.image_size // self.patch_size) ** 2
272
+ self.num_positions = self.num_patches
273
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
274
+ self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
275
+
276
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
277
+ target_dtype = self.patch_embedding.weight.dtype
278
+ patch_embeds = self.patch_embedding(pixel_values.to(target_dtype)) # shape = [*, width, grid, grid]
279
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
280
+
281
+ embeddings = embeddings + self.position_embedding(self.position_ids)
282
+ return embeddings
283
+
284
+
285
+ # Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip
286
+ class SiglipTextEmbeddings(nn.Module):
287
+ def __init__(self, config: SiglipTextConfig):
288
+ super().__init__()
289
+ embed_dim = config.hidden_size
290
+
291
+ self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
292
+ self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
293
+
294
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
295
+ self.register_buffer(
296
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
297
+ )
298
+
299
+ def forward(
300
+ self,
301
+ input_ids: Optional[torch.LongTensor] = None,
302
+ position_ids: Optional[torch.LongTensor] = None,
303
+ inputs_embeds: Optional[torch.FloatTensor] = None,
304
+ ) -> torch.Tensor:
305
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
306
+
307
+ if position_ids is None:
308
+ position_ids = self.position_ids[:, :seq_length]
309
+
310
+ if inputs_embeds is None:
311
+ inputs_embeds = self.token_embedding(input_ids)
312
+
313
+ position_embeddings = self.position_embedding(position_ids)
314
+ embeddings = inputs_embeds + position_embeddings
315
+
316
+ return embeddings
317
+
318
+
319
+ class SiglipAttention(nn.Module):
320
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
321
+
322
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
323
+ def __init__(self, config):
324
+ super().__init__()
325
+ self.config = config
326
+ self.embed_dim = config.hidden_size
327
+ self.num_heads = config.num_attention_heads
328
+ self.head_dim = self.embed_dim // self.num_heads
329
+ if self.head_dim * self.num_heads != self.embed_dim:
330
+ raise ValueError(
331
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
332
+ f" {self.num_heads})."
333
+ )
334
+ self.scale = self.head_dim**-0.5
335
+ self.dropout = config.attention_dropout
336
+
337
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
338
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
339
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
340
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
341
+
342
+ def forward(
343
+ self,
344
+ hidden_states: torch.Tensor,
345
+ attention_mask: Optional[torch.Tensor] = None,
346
+ output_attentions: Optional[bool] = False,
347
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
348
+ """Input shape: Batch x Time x Channel"""
349
+
350
+ batch_size, q_len, _ = hidden_states.size()
351
+
352
+ query_states = self.q_proj(hidden_states)
353
+ key_states = self.k_proj(hidden_states)
354
+ value_states = self.v_proj(hidden_states)
355
+
356
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
357
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
358
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
359
+
360
+ k_v_seq_len = key_states.shape[-2]
361
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
362
+
363
+ if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
364
+ raise ValueError(
365
+ f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
366
+ f" {attn_weights.size()}"
367
+ )
368
+
369
+ if attention_mask is not None:
370
+ if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
371
+ raise ValueError(
372
+ f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
373
+ )
374
+ attn_weights = attn_weights + attention_mask
375
+
376
+ # upcast attention to fp32
377
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
378
+ attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
379
+ attn_output = torch.matmul(attn_weights, value_states)
380
+
381
+ if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
382
+ raise ValueError(
383
+ f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
384
+ f" {attn_output.size()}"
385
+ )
386
+
387
+ attn_output = attn_output.transpose(1, 2).contiguous()
388
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
389
+
390
+ attn_output = self.out_proj(attn_output)
391
+
392
+ return attn_output, attn_weights
393
+
394
+
395
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip
396
+ class SiglipMLP(nn.Module):
397
+ def __init__(self, config):
398
+ super().__init__()
399
+ self.config = config
400
+ self.activation_fn = ACT2FN[config.hidden_act]
401
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
402
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
403
+
404
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
405
+ hidden_states = self.fc1(hidden_states)
406
+ hidden_states = self.activation_fn(hidden_states)
407
+ hidden_states = self.fc2(hidden_states)
408
+ return hidden_states
409
+
410
+
411
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip
412
+ class SiglipEncoderLayer(nn.Module):
413
+ def __init__(self, config: SiglipConfig):
414
+ super().__init__()
415
+ self.embed_dim = config.hidden_size
416
+ self.self_attn = SiglipAttention(config)
417
+ self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
418
+ self.mlp = SiglipMLP(config)
419
+ self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
420
+
421
+ # Ignore copy
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ attention_mask: torch.Tensor,
426
+ output_attentions: Optional[bool] = False,
427
+ ) -> Tuple[torch.FloatTensor]:
428
+ """
429
+ Args:
430
+ hidden_states (`torch.FloatTensor`):
431
+ Input to the layer of shape `(batch, seq_len, embed_dim)`.
432
+ attention_mask (`torch.FloatTensor`):
433
+ Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
434
+ output_attentions (`bool`, *optional*, defaults to `False`):
435
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
436
+ returned tensors for more detail.
437
+ """
438
+ residual = hidden_states
439
+
440
+ hidden_states = self.layer_norm1(hidden_states)
441
+ hidden_states, attn_weights = self.self_attn(
442
+ hidden_states=hidden_states,
443
+ attention_mask=attention_mask,
444
+ output_attentions=output_attentions,
445
+ )
446
+ hidden_states = residual + hidden_states
447
+
448
+ residual = hidden_states
449
+ hidden_states = self.layer_norm2(hidden_states)
450
+ hidden_states = self.mlp(hidden_states)
451
+ hidden_states = residual + hidden_states
452
+
453
+ outputs = (hidden_states,)
454
+
455
+ if output_attentions:
456
+ outputs += (attn_weights,)
457
+
458
+ return outputs
459
+
460
+
461
+ class SiglipPreTrainedModel(PreTrainedModel):
462
+ """
463
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
464
+ models.
465
+ """
466
+
467
+ config_class = SiglipConfig
468
+ base_model_prefix = "siglip"
469
+ supports_gradient_checkpointing = True
470
+
471
+ def _init_weights(self, module):
472
+ """Initialize the weights"""
473
+ if isinstance(module, SiglipVisionEmbeddings):
474
+ width = (
475
+ self.config.vision_config.hidden_size
476
+ if isinstance(self.config, SiglipConfig)
477
+ else self.config.hidden_size
478
+ )
479
+ nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
480
+ elif isinstance(module, nn.Embedding):
481
+ default_flax_embed_init(module.weight)
482
+ elif isinstance(module, SiglipAttention):
483
+ nn.init.xavier_uniform_(module.q_proj.weight)
484
+ nn.init.xavier_uniform_(module.k_proj.weight)
485
+ nn.init.xavier_uniform_(module.v_proj.weight)
486
+ nn.init.xavier_uniform_(module.out_proj.weight)
487
+ nn.init.zeros_(module.q_proj.bias)
488
+ nn.init.zeros_(module.k_proj.bias)
489
+ nn.init.zeros_(module.v_proj.bias)
490
+ nn.init.zeros_(module.out_proj.bias)
491
+ elif isinstance(module, SiglipMLP):
492
+ nn.init.xavier_uniform_(module.fc1.weight)
493
+ nn.init.xavier_uniform_(module.fc2.weight)
494
+ nn.init.normal_(module.fc1.bias, std=1e-6)
495
+ nn.init.normal_(module.fc2.bias, std=1e-6)
496
+ elif isinstance(module, SiglipMultiheadAttentionPoolingHead):
497
+ nn.init.xavier_uniform_(module.probe.data)
498
+ nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
499
+ nn.init.zeros_(module.attention.in_proj_bias.data)
500
+ elif isinstance(module, SiglipModel):
501
+ logit_scale_init = torch.log(torch.tensor(1.0))
502
+ module.logit_scale.data.fill_(logit_scale_init)
503
+ module.logit_bias.data.zero_()
504
+ elif isinstance(module, (nn.Linear, nn.Conv2d)):
505
+ lecun_normal_(module.weight)
506
+ if module.bias is not None:
507
+ nn.init.zeros_(module.bias)
508
+ elif isinstance(module, nn.LayerNorm):
509
+ module.bias.data.zero_()
510
+ module.weight.data.fill_(1.0)
511
+
512
+
513
+ SIGLIP_START_DOCSTRING = r"""
514
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
515
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
516
+ etc.)
517
+
518
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
519
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
520
+ and behavior.
521
+
522
+ Parameters:
523
+ config ([`SiglipConfig`]): Model configuration class with all the parameters of the model.
524
+ Initializing with a config file does not load the weights associated with the model, only the
525
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
526
+ """
527
+
528
+ SIGLIP_TEXT_INPUTS_DOCSTRING = r"""
529
+ Args:
530
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
531
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
532
+ it.
533
+
534
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
535
+ [`PreTrainedTokenizer.__call__`] for details.
536
+
537
+ [What are input IDs?](../glossary#input-ids)
538
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
539
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
540
+
541
+ - 1 for tokens that are **not masked**,
542
+ - 0 for tokens that are **masked**.
543
+
544
+ [What are attention masks?](../glossary#attention-mask)
545
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
546
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
547
+ config.max_position_embeddings - 1]`.
548
+
549
+ [What are position IDs?](../glossary#position-ids)
550
+ output_attentions (`bool`, *optional*):
551
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
552
+ tensors for more detail.
553
+ output_hidden_states (`bool`, *optional*):
554
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
555
+ more detail.
556
+ return_dict (`bool`, *optional*):
557
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
558
+ """
559
+
560
+ SIGLIP_VISION_INPUTS_DOCSTRING = r"""
561
+ Args:
562
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
563
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
564
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
565
+ output_attentions (`bool`, *optional*):
566
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
567
+ tensors for more detail.
568
+ output_hidden_states (`bool`, *optional*):
569
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
570
+ more detail.
571
+ return_dict (`bool`, *optional*):
572
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
573
+ """
574
+
575
+ SIGLIP_INPUTS_DOCSTRING = r"""
576
+ Args:
577
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
578
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
579
+ it.
580
+
581
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
582
+ [`PreTrainedTokenizer.__call__`] for details.
583
+
584
+ [What are input IDs?](../glossary#input-ids)
585
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
586
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
587
+
588
+ - 1 for tokens that are **not masked**,
589
+ - 0 for tokens that are **masked**.
590
+
591
+ [What are attention masks?](../glossary#attention-mask)
592
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
593
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
594
+ config.max_position_embeddings - 1]`.
595
+
596
+ [What are position IDs?](../glossary#position-ids)
597
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
598
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
599
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
600
+ return_loss (`bool`, *optional*):
601
+ Whether or not to return the contrastive loss.
602
+ output_attentions (`bool`, *optional*):
603
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
604
+ tensors for more detail.
605
+ output_hidden_states (`bool`, *optional*):
606
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
607
+ more detail.
608
+ return_dict (`bool`, *optional*):
609
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
610
+ """
611
+
612
+
613
+ # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip
614
+ class SiglipEncoder(nn.Module):
615
+ """
616
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
617
+ [`SiglipEncoderLayer`].
618
+
619
+ Args:
620
+ config: SiglipConfig
621
+ """
622
+
623
+ def __init__(self, config: SiglipConfig):
624
+ super().__init__()
625
+ self.config = config
626
+ self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
627
+ self.gradient_checkpointing = False
628
+
629
+ # Ignore copy
630
+ def forward(
631
+ self,
632
+ inputs_embeds,
633
+ attention_mask: Optional[torch.Tensor] = None,
634
+ output_attentions: Optional[bool] = None,
635
+ output_hidden_states: Optional[bool] = None,
636
+ return_dict: Optional[bool] = None,
637
+ ) -> Union[Tuple, BaseModelOutput]:
638
+ r"""
639
+ Args:
640
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
641
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
642
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
643
+ than the model's internal embedding lookup matrix.
644
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
645
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
646
+
647
+ - 1 for tokens that are **not masked**,
648
+ - 0 for tokens that are **masked**.
649
+
650
+ [What are attention masks?](../glossary#attention-mask)
651
+ output_attentions (`bool`, *optional*):
652
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
653
+ returned tensors for more detail.
654
+ output_hidden_states (`bool`, *optional*):
655
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
656
+ for more detail.
657
+ return_dict (`bool`, *optional*):
658
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
659
+ """
660
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
661
+ output_hidden_states = (
662
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
663
+ )
664
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
665
+
666
+ encoder_states = () if output_hidden_states else None
667
+ all_attentions = () if output_attentions else None
668
+
669
+ hidden_states = inputs_embeds
670
+ for encoder_layer in self.layers:
671
+ if output_hidden_states:
672
+ encoder_states = encoder_states + (hidden_states,)
673
+ if self.gradient_checkpointing and self.training:
674
+ layer_outputs = self._gradient_checkpointing_func(
675
+ encoder_layer.__call__,
676
+ hidden_states,
677
+ attention_mask,
678
+ output_attentions,
679
+ )
680
+ else:
681
+ layer_outputs = encoder_layer(
682
+ hidden_states,
683
+ attention_mask,
684
+ output_attentions=output_attentions,
685
+ )
686
+
687
+ hidden_states = layer_outputs[0]
688
+
689
+ if output_attentions:
690
+ all_attentions = all_attentions + (layer_outputs[1],)
691
+
692
+ if output_hidden_states:
693
+ encoder_states = encoder_states + (hidden_states,)
694
+
695
+ if not return_dict:
696
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
697
+ return BaseModelOutput(
698
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
699
+ )
700
+
701
+
702
+ class SiglipTextTransformer(nn.Module):
703
+ def __init__(self, config: SiglipTextConfig):
704
+ super().__init__()
705
+ self.config = config
706
+ embed_dim = config.hidden_size
707
+ self.embeddings = SiglipTextEmbeddings(config)
708
+ self.encoder = SiglipEncoder(config)
709
+ self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
710
+
711
+ self.head = nn.Linear(embed_dim, embed_dim)
712
+
713
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
714
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
715
+ def forward(
716
+ self,
717
+ input_ids: Optional[torch.Tensor] = None,
718
+ attention_mask: Optional[torch.Tensor] = None,
719
+ position_ids: Optional[torch.Tensor] = None,
720
+ output_attentions: Optional[bool] = None,
721
+ output_hidden_states: Optional[bool] = None,
722
+ return_dict: Optional[bool] = None,
723
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
724
+ r"""
725
+ Returns:
726
+
727
+ """
728
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
729
+ output_hidden_states = (
730
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
731
+ )
732
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
733
+
734
+ if input_ids is None:
735
+ raise ValueError("You have to specify input_ids")
736
+
737
+ input_shape = input_ids.size()
738
+ input_ids = input_ids.view(-1, input_shape[-1])
739
+
740
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
741
+
742
+ # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
743
+ # expand attention_mask
744
+ if attention_mask is not None:
745
+ # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
746
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
747
+
748
+ encoder_outputs = self.encoder(
749
+ inputs_embeds=hidden_states,
750
+ attention_mask=attention_mask,
751
+ output_attentions=output_attentions,
752
+ output_hidden_states=output_hidden_states,
753
+ return_dict=return_dict,
754
+ )
755
+
756
+ last_hidden_state = encoder_outputs[0]
757
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
758
+
759
+ # Assuming "sticky" EOS tokenization, last token is always EOS.
760
+ pooled_output = last_hidden_state[:, -1, :]
761
+ pooled_output = self.head(pooled_output)
762
+
763
+ if not return_dict:
764
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
765
+
766
+ return BaseModelOutputWithPooling(
767
+ last_hidden_state=last_hidden_state,
768
+ pooler_output=pooled_output,
769
+ hidden_states=encoder_outputs.hidden_states,
770
+ attentions=encoder_outputs.attentions,
771
+ )
772
+
773
+
774
+ @add_start_docstrings(
775
+ """The text model from SigLIP without any head or projection on top.""",
776
+ SIGLIP_START_DOCSTRING,
777
+ )
778
+ class SiglipTextModel(SiglipPreTrainedModel):
779
+ config_class = SiglipTextConfig
780
+
781
+ _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"]
782
+
783
+ def __init__(self, config: SiglipTextConfig):
784
+ super().__init__(config)
785
+ self.text_model = SiglipTextTransformer(config)
786
+ # Initialize weights and apply final processing
787
+ self.post_init()
788
+
789
+ def get_input_embeddings(self) -> nn.Module:
790
+ return self.text_model.embeddings.token_embedding
791
+
792
+ def set_input_embeddings(self, value):
793
+ self.text_model.embeddings.token_embedding = value
794
+
795
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
796
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
797
+ def forward(
798
+ self,
799
+ input_ids: Optional[torch.Tensor] = None,
800
+ attention_mask: Optional[torch.Tensor] = None,
801
+ position_ids: Optional[torch.Tensor] = None,
802
+ output_attentions: Optional[bool] = None,
803
+ output_hidden_states: Optional[bool] = None,
804
+ return_dict: Optional[bool] = None,
805
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
806
+ r"""
807
+ Returns:
808
+
809
+ Examples:
810
+
811
+ ```python
812
+ >>> from transformers import AutoTokenizer, SiglipTextModel
813
+
814
+ >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224")
815
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
816
+
817
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
818
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
819
+
820
+ >>> outputs = model(**inputs)
821
+ >>> last_hidden_state = outputs.last_hidden_state
822
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
823
+ ```"""
824
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
825
+
826
+ return self.text_model(
827
+ input_ids=input_ids,
828
+ attention_mask=attention_mask,
829
+ position_ids=position_ids,
830
+ output_attentions=output_attentions,
831
+ output_hidden_states=output_hidden_states,
832
+ return_dict=return_dict,
833
+ )
834
+
835
+
836
+ class SiglipVisionTransformer(nn.Module):
837
+ def __init__(self, config: SiglipVisionConfig):
838
+ super().__init__()
839
+ self.config = config
840
+ embed_dim = config.hidden_size
841
+
842
+ self.embeddings = SiglipVisionEmbeddings(config)
843
+ self.encoder = SiglipEncoder(config)
844
+ self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
845
+ self.head = SiglipMultiheadAttentionPoolingHead(config)
846
+
847
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
848
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
849
+ def forward(
850
+ self,
851
+ pixel_values,
852
+ output_attentions: Optional[bool] = None,
853
+ output_hidden_states: Optional[bool] = None,
854
+ return_dict: Optional[bool] = None,
855
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
856
+ r"""
857
+ Returns:
858
+
859
+ """
860
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
861
+ output_hidden_states = (
862
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
863
+ )
864
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
865
+
866
+ hidden_states = self.embeddings(pixel_values)
867
+
868
+ encoder_outputs = self.encoder(
869
+ inputs_embeds=hidden_states,
870
+ output_attentions=output_attentions,
871
+ output_hidden_states=output_hidden_states,
872
+ return_dict=return_dict,
873
+ )
874
+
875
+ last_hidden_state = encoder_outputs[0]
876
+ last_hidden_state = self.post_layernorm(last_hidden_state)
877
+
878
+ pooled_output = self.head(last_hidden_state)
879
+
880
+ if not return_dict:
881
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
882
+
883
+ return BaseModelOutputWithPooling(
884
+ last_hidden_state=last_hidden_state,
885
+ pooler_output=pooled_output,
886
+ hidden_states=encoder_outputs.hidden_states,
887
+ attentions=encoder_outputs.attentions,
888
+ )
889
+
890
+
891
+ class SiglipMultiheadAttentionPoolingHead(nn.Module):
892
+ """Multihead Attention Pooling."""
893
+
894
+ def __init__(self, config: SiglipVisionConfig):
895
+ super().__init__()
896
+
897
+ self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
898
+ self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
899
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
900
+ self.mlp = SiglipMLP(config)
901
+
902
+ def forward(self, hidden_state):
903
+ batch_size = hidden_state.shape[0]
904
+ probe = self.probe.repeat(batch_size, 1, 1)
905
+
906
+ hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
907
+
908
+ residual = hidden_state
909
+ hidden_state = self.layernorm(hidden_state)
910
+ hidden_state = residual + self.mlp(hidden_state)
911
+
912
+ return hidden_state[:, 0]
913
+
914
+
915
+ @add_start_docstrings(
916
+ """The vision model from SigLIP without any head or projection on top.""",
917
+ SIGLIP_START_DOCSTRING,
918
+ )
919
+ class SiglipVisionModel(SiglipPreTrainedModel):
920
+ config_class = SiglipVisionConfig
921
+ main_input_name = "pixel_values"
922
+
923
+ def __init__(self, config: SiglipVisionConfig):
924
+ super().__init__(config)
925
+
926
+ self.vision_model = SiglipVisionTransformer(config)
927
+
928
+ # Initialize weights and apply final processing
929
+ self.post_init()
930
+
931
+ def get_input_embeddings(self) -> nn.Module:
932
+ return self.vision_model.embeddings.patch_embedding
933
+
934
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
935
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig)
936
+ def forward(
937
+ self,
938
+ pixel_values,
939
+ output_attentions: Optional[bool] = None,
940
+ output_hidden_states: Optional[bool] = None,
941
+ return_dict: Optional[bool] = None,
942
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
943
+ r"""
944
+ Returns:
945
+
946
+ Examples:
947
+
948
+ ```python
949
+ >>> from PIL import Image
950
+ >>> import requests
951
+ >>> from transformers import AutoProcessor, SiglipVisionModel
952
+
953
+ >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
954
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
955
+
956
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
957
+ >>> image = Image.open(requests.get(url, stream=True).raw)
958
+
959
+ >>> inputs = processor(images=image, return_tensors="pt")
960
+
961
+ >>> outputs = model(**inputs)
962
+ >>> last_hidden_state = outputs.last_hidden_state
963
+ >>> pooled_output = outputs.pooler_output # pooled features
964
+ ```"""
965
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
966
+
967
+ return self.vision_model(
968
+ pixel_values=pixel_values,
969
+ output_attentions=output_attentions,
970
+ output_hidden_states=output_hidden_states,
971
+ return_dict=return_dict,
972
+ )
973
+
974
+
975
+ @add_start_docstrings(SIGLIP_START_DOCSTRING)
976
+ class SiglipModel(SiglipPreTrainedModel):
977
+ config_class = SiglipConfig
978
+
979
+ def __init__(self, config: SiglipConfig):
980
+ super().__init__(config)
981
+
982
+ if not isinstance(config.text_config, SiglipTextConfig):
983
+ raise ValueError(
984
+ "config.text_config is expected to be of type SiglipTextConfig but is of type"
985
+ f" {type(config.text_config)}."
986
+ )
987
+
988
+ if not isinstance(config.vision_config, SiglipVisionConfig):
989
+ raise ValueError(
990
+ "config.vision_config is expected to be of type SiglipVisionConfig but is of type"
991
+ f" {type(config.vision_config)}."
992
+ )
993
+
994
+ text_config = config.text_config
995
+ vision_config = config.vision_config
996
+
997
+ self.text_model = SiglipTextTransformer(text_config)
998
+ self.vision_model = SiglipVisionTransformer(vision_config)
999
+
1000
+ self.logit_scale = nn.Parameter(torch.randn(1))
1001
+ self.logit_bias = nn.Parameter(torch.randn(1))
1002
+
1003
+ # Initialize weights and apply final processing
1004
+ self.post_init()
1005
+
1006
+ @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
1007
+ def get_text_features(
1008
+ self,
1009
+ input_ids: Optional[torch.Tensor] = None,
1010
+ attention_mask: Optional[torch.Tensor] = None,
1011
+ position_ids: Optional[torch.Tensor] = None,
1012
+ output_attentions: Optional[bool] = None,
1013
+ output_hidden_states: Optional[bool] = None,
1014
+ return_dict: Optional[bool] = None,
1015
+ ) -> torch.FloatTensor:
1016
+ r"""
1017
+ Returns:
1018
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
1019
+ applying the projection layer to the pooled output of [`SiglipTextModel`].
1020
+
1021
+ Examples:
1022
+
1023
+ ```python
1024
+ >>> from transformers import AutoTokenizer, AutoModel
1025
+ >>> import torch
1026
+
1027
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1028
+ >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
1029
+
1030
+ >>> # important: make sure to set padding="max_length" as that's how the model was trained
1031
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt")
1032
+ >>> with torch.no_grad():
1033
+ ... text_features = model.get_text_features(**inputs)
1034
+ ```"""
1035
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1036
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1037
+ output_hidden_states = (
1038
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1039
+ )
1040
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1041
+
1042
+ text_outputs = self.text_model(
1043
+ input_ids=input_ids,
1044
+ attention_mask=attention_mask,
1045
+ position_ids=position_ids,
1046
+ output_attentions=output_attentions,
1047
+ output_hidden_states=output_hidden_states,
1048
+ return_dict=return_dict,
1049
+ )
1050
+
1051
+ pooled_output = text_outputs[1]
1052
+
1053
+ return pooled_output
1054
+
1055
+ @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING)
1056
+ def get_image_features(
1057
+ self,
1058
+ pixel_values: Optional[torch.FloatTensor] = None,
1059
+ output_attentions: Optional[bool] = None,
1060
+ output_hidden_states: Optional[bool] = None,
1061
+ return_dict: Optional[bool] = None,
1062
+ ) -> torch.FloatTensor:
1063
+ r"""
1064
+ Returns:
1065
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
1066
+ applying the projection layer to the pooled output of [`SiglipVisionModel`].
1067
+
1068
+ Examples:
1069
+
1070
+ ```python
1071
+ >>> from PIL import Image
1072
+ >>> import requests
1073
+ >>> from transformers import AutoProcessor, AutoModel
1074
+ >>> import torch
1075
+
1076
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1077
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1078
+
1079
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1080
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1081
+
1082
+ >>> inputs = processor(images=image, return_tensors="pt")
1083
+
1084
+ >>> with torch.no_grad():
1085
+ ... image_features = model.get_image_features(**inputs)
1086
+ ```"""
1087
+ # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components.
1088
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1089
+ output_hidden_states = (
1090
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1091
+ )
1092
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1093
+
1094
+ vision_outputs = self.vision_model(
1095
+ pixel_values=pixel_values,
1096
+ output_attentions=output_attentions,
1097
+ output_hidden_states=output_hidden_states,
1098
+ return_dict=return_dict,
1099
+ )
1100
+
1101
+ pooled_output = vision_outputs[1]
1102
+
1103
+ return pooled_output
1104
+
1105
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1106
+ @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig)
1107
+ def forward(
1108
+ self,
1109
+ input_ids: Optional[torch.LongTensor] = None,
1110
+ pixel_values: Optional[torch.FloatTensor] = None,
1111
+ attention_mask: Optional[torch.Tensor] = None,
1112
+ position_ids: Optional[torch.LongTensor] = None,
1113
+ return_loss: Optional[bool] = None,
1114
+ output_attentions: Optional[bool] = None,
1115
+ output_hidden_states: Optional[bool] = None,
1116
+ return_dict: Optional[bool] = None,
1117
+ ) -> Union[Tuple, SiglipOutput]:
1118
+ r"""
1119
+ Returns:
1120
+
1121
+ Examples:
1122
+
1123
+ ```python
1124
+ >>> from PIL import Image
1125
+ >>> import requests
1126
+ >>> from transformers import AutoProcessor, AutoModel
1127
+ >>> import torch
1128
+
1129
+ >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224")
1130
+ >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
1131
+
1132
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1133
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1134
+
1135
+ >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"]
1136
+ >>> # important: we pass `padding=max_length` since the model was trained with this
1137
+ >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")
1138
+
1139
+ >>> with torch.no_grad():
1140
+ ... outputs = model(**inputs)
1141
+
1142
+ >>> logits_per_image = outputs.logits_per_image
1143
+ >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities
1144
+ >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")
1145
+ 31.9% that image 0 is 'a photo of 2 cats'
1146
+ ```"""
1147
+ # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components.
1148
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1149
+ output_hidden_states = (
1150
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1151
+ )
1152
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1153
+
1154
+ vision_outputs = self.vision_model(
1155
+ pixel_values=pixel_values,
1156
+ output_attentions=output_attentions,
1157
+ output_hidden_states=output_hidden_states,
1158
+ return_dict=return_dict,
1159
+ )
1160
+
1161
+ text_outputs = self.text_model(
1162
+ input_ids=input_ids,
1163
+ attention_mask=attention_mask,
1164
+ position_ids=position_ids,
1165
+ output_attentions=output_attentions,
1166
+ output_hidden_states=output_hidden_states,
1167
+ return_dict=return_dict,
1168
+ )
1169
+
1170
+ image_embeds = vision_outputs[1]
1171
+ text_embeds = text_outputs[1]
1172
+
1173
+ # normalized features
1174
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1175
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1176
+
1177
+ # cosine similarity as logits
1178
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias
1179
+ logits_per_image = logits_per_text.t()
1180
+
1181
+ loss = None
1182
+ if return_loss:
1183
+ raise NotImplementedError("SigLIP loss to be implemented")
1184
+
1185
+ if not return_dict:
1186
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1187
+ return ((loss,) + output) if loss is not None else output
1188
+
1189
+ return SiglipOutput(
1190
+ loss=loss,
1191
+ logits_per_image=logits_per_image,
1192
+ logits_per_text=logits_per_text,
1193
+ text_embeds=text_embeds,
1194
+ image_embeds=image_embeds,
1195
+ text_model_output=text_outputs,
1196
+ vision_model_output=vision_outputs,
1197
+ )
1198
+
1199
+
1200
+ @add_start_docstrings(
1201
+ """
1202
+ SigLIP vision encoder with an image classification head on top (a linear layer on top of the pooled final hidden states of
1203
+ the patch tokens) e.g. for ImageNet.
1204
+ """,
1205
+ SIGLIP_START_DOCSTRING,
1206
+ )
1207
+ class SiglipForImageClassification(SiglipPreTrainedModel):
1208
+ main_input_name = "pixel_values"
1209
+
1210
+ def __init__(self, config: SiglipConfig) -> None:
1211
+ super().__init__(config)
1212
+
1213
+ self.num_labels = config.num_labels
1214
+ self.vision_model = SiglipVisionTransformer(config.vision_config)
1215
+
1216
+ # Classifier head
1217
+ self.classifier = (
1218
+ nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1219
+ )
1220
+
1221
+ # Initialize weights and apply final processing
1222
+ self.post_init()
1223
+
1224
+ @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING)
1225
+ @add_code_sample_docstrings(
1226
+ checkpoint=_IMAGE_CLASS_CHECKPOINT,
1227
+ output_type=ImageClassifierOutput,
1228
+ config_class=_CONFIG_FOR_DOC,
1229
+ expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT,
1230
+ )
1231
+ def forward(
1232
+ self,
1233
+ pixel_values: Optional[torch.Tensor] = None,
1234
+ labels: Optional[torch.Tensor] = None,
1235
+ output_attentions: Optional[bool] = None,
1236
+ output_hidden_states: Optional[bool] = None,
1237
+ return_dict: Optional[bool] = None,
1238
+ ) -> Union[tuple, ImageClassifierOutput]:
1239
+ r"""
1240
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1241
+ Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
1242
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1243
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1244
+ """
1245
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1246
+ output_hidden_states = (
1247
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1248
+ )
1249
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1250
+
1251
+ outputs = self.vision_model(
1252
+ pixel_values,
1253
+ output_attentions=output_attentions,
1254
+ output_hidden_states=output_hidden_states,
1255
+ return_dict=return_dict,
1256
+ )
1257
+
1258
+ sequence_output = outputs[0]
1259
+
1260
+ # average pool the patch tokens
1261
+ sequence_output = torch.mean(sequence_output[:, 1:, :], dim=1)
1262
+ # apply classifier
1263
+ logits = self.classifier(sequence_output)
1264
+
1265
+ loss = None
1266
+ if labels is not None:
1267
+ # move labels to correct device to enable model parallelism
1268
+ labels = labels.to(logits.device)
1269
+ if self.config.problem_type is None:
1270
+ if self.num_labels == 1:
1271
+ self.config.problem_type = "regression"
1272
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1273
+ self.config.problem_type = "single_label_classification"
1274
+ else:
1275
+ self.config.problem_type = "multi_label_classification"
1276
+
1277
+ if self.config.problem_type == "regression":
1278
+ loss_fct = MSELoss()
1279
+ if self.num_labels == 1:
1280
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1281
+ else:
1282
+ loss = loss_fct(logits, labels)
1283
+ elif self.config.problem_type == "single_label_classification":
1284
+ loss_fct = CrossEntropyLoss()
1285
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1286
+ elif self.config.problem_type == "multi_label_classification":
1287
+ loss_fct = BCEWithLogitsLoss()
1288
+ loss = loss_fct(logits, labels)
1289
+
1290
+ if not return_dict:
1291
+ output = (logits,) + outputs[2:]
1292
+ return ((loss,) + output) if loss is not None else output
1293
+
1294
+ return ImageClassifierOutput(
1295
+ loss=loss,
1296
+ logits=logits,
1297
+ hidden_states=outputs.hidden_states,
1298
+ attentions=outputs.attentions,
1299
+ )
modified_xtuner/xtuner/dataset/huggingface.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+ import os
4
+ from datetime import timedelta
5
+ from functools import partial
6
+
7
+ import numpy as np
8
+ from datasets import DatasetDict, concatenate_datasets
9
+ from mmengine import print_log
10
+ from mmengine.config import Config, ConfigDict
11
+ from mmengine.utils.misc import get_object_from_string
12
+ from torch import distributed as dist
13
+
14
+ from xtuner.registry import BUILDER, MAP_FUNC
15
+ from .utils import Packer, encode_fn
16
+
17
+
18
+ def get_lengths(example):
19
+ return {'length': len(example['input_ids'])}
20
+
21
+
22
+ def build_origin_dataset(dataset, split):
23
+ if isinstance(dataset, DatasetDict):
24
+ if split is None:
25
+ dataset = concatenate_datasets(dataset.values())
26
+ else:
27
+ dataset = dataset[split]
28
+ elif isinstance(dataset, dict) or isinstance(
29
+ dataset, Config) or isinstance(dataset, ConfigDict):
30
+ dataset = BUILDER.build(dataset)
31
+ if isinstance(dataset, DatasetDict):
32
+ if split is None:
33
+ dataset = concatenate_datasets(dataset.values())
34
+ else:
35
+ dataset = dataset[split]
36
+ return dataset
37
+
38
+
39
+ def map_dataset(dataset, dataset_map_fn, map_num_proc):
40
+ if isinstance(dataset_map_fn, str):
41
+ map_fn_obj = MAP_FUNC.get(dataset_map_fn) or get_object_from_string(
42
+ dataset_map_fn)
43
+ if map_fn_obj is not None:
44
+ dataset_map_fn = map_fn_obj
45
+ else:
46
+ raise TypeError('dataset_map_fn must be a function or a '
47
+ "registered function's string in MAP_FUNC, "
48
+ f"but got a string of '{dataset_map_fn}'")
49
+
50
+ dataset = dataset.map(dataset_map_fn, num_proc=map_num_proc)
51
+ return dataset
52
+
53
+
54
+ def add_template_to_dataset(dataset, template_map_fn, map_num_proc):
55
+ if isinstance(template_map_fn,
56
+ dict) or isinstance(template_map_fn, Config) or isinstance(
57
+ template_map_fn, ConfigDict):
58
+ template_map_fn = BUILDER.build(template_map_fn)
59
+ dataset = dataset.map(template_map_fn, num_proc=map_num_proc)
60
+ # remove invalid data
61
+ dataset = dataset.filter(
62
+ lambda example: len(example['conversation']) > 0,
63
+ num_proc=map_num_proc)
64
+ return dataset
65
+
66
+
67
+ def tokenize_dataset(dataset, tokenizer, max_length, with_image_token,
68
+ input_ids_with_output, remove_unused_columns,
69
+ map_num_proc):
70
+ assert (tokenizer is not None) and (max_length is not None), \
71
+ f'({tokenizer}, {max_length})'
72
+ if isinstance(tokenizer, dict) or isinstance(
73
+ tokenizer, Config) or isinstance(tokenizer, ConfigDict):
74
+ tokenizer = BUILDER.build(tokenizer)
75
+ dataset = dataset.map(
76
+ partial(
77
+ encode_fn,
78
+ tokenizer=tokenizer,
79
+ max_length=max_length,
80
+ with_image_token=with_image_token,
81
+ input_ids_with_output=input_ids_with_output),
82
+ remove_columns=list(dataset.column_names)
83
+ if remove_unused_columns else None,
84
+ num_proc=map_num_proc)
85
+ return dataset
86
+
87
+
88
+ def pack_dataset(dataset, max_length, use_varlen_attn, shuffle_before_pack,
89
+ map_num_proc):
90
+ if shuffle_before_pack:
91
+ dataset = dataset.shuffle()
92
+ dataset = dataset.flatten_indices(num_proc=map_num_proc)
93
+ dataset = dataset.map(
94
+ Packer(max_length, use_varlen_attn=use_varlen_attn),
95
+ batched=True,
96
+ num_proc=map_num_proc)
97
+ return dataset
98
+
99
+
100
+ def process(dataset,
101
+ do_dataset_tokenization=True,
102
+ tokenizer=None,
103
+ max_length=None,
104
+ dataset_map_fn=None,
105
+ template_map_fn=None,
106
+ max_dataset_length=None,
107
+ split='train',
108
+ remove_unused_columns=False,
109
+ rename_maps=[],
110
+ shuffle_before_pack=True,
111
+ pack_to_max_length=True,
112
+ use_varlen_attn=False,
113
+ input_ids_with_output=True,
114
+ with_image_token=False,
115
+ map_num_proc=32):
116
+ """Post-process the dataset loaded from the Hugging Face Hub, or a local
117
+ dataset.
118
+
119
+ Args:
120
+ dataset: The dataset to be post-processed.
121
+ do_dataset_tokenization: Whether the dataset need to be tokenized
122
+ in this function. Default to True.
123
+ tokenizer: The tokenizer processes some raw text as input and outputs
124
+ an Encoding. If `do_dataset_tokenization` is True, this argument
125
+ should not be None. Default to None.
126
+ max_length: Max length of the sequence. If `do_dataset_tokenization`
127
+ or `pack_to_max_length` is True, this argument should not be None.
128
+ Default to None.
129
+ dataset_map_fn: Map the original dataset format to the one defined
130
+ by xTuner.
131
+ template_map_fn: Add the prompt template to the dataset
132
+ max_dataset_length: If the length of the dataset is too long, we can
133
+ randomly extract `max_dataset_length` from it.
134
+ split: Which split of the data to load.
135
+ If `None`, will return a single concatenated dataset with all
136
+ splits (typically `datasets.Split.TRAIN` and
137
+ `datasets.Split.TEST`).
138
+ If given, will return a single Dataset.
139
+ remove_unused_columns: Whether to remove columns from the dataset
140
+ that are not used during training.
141
+ rename_maps: Rename the column name of the dataset.
142
+ shuffle_before_pack: Whether to shuffle the dataset before
143
+ packing them.
144
+ pack_to_max_length: Whether to pack the dataset to the `max_length `.
145
+ This usually improves gpu utilization and therefore reduces
146
+ training time.
147
+ use_varlen_attn: If use_varlen_attn is True, we calculate attention
148
+ the actual length of the sequence rather than the actual length
149
+ of the sequence
150
+ input_ids_with_output: Whether to put the groundtruth output
151
+ corresponding to the question into the dataset. Typically set
152
+ it to True during training and False during testing.
153
+ with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
154
+ IMAGE_TOKEN_INDEX. Typically set it to True during the training
155
+ of VLM.
156
+ map_num_proc: Max number of processes when mapping the dataset.
157
+ """
158
+ if use_varlen_attn:
159
+ assert pack_to_max_length, \
160
+ '`pack_to_max_length` in `process_hf_dataset` should be set to ' \
161
+ 'True if `use_varlen_attn` is True.'
162
+ if pack_to_max_length:
163
+ assert split == 'train' or split is None, \
164
+ ('`split` should be `train` or `None` if `pack_to_max_length` is '
165
+ f'True, but got {split}.')
166
+
167
+ dataset = build_origin_dataset(dataset, split)
168
+
169
+ # sample `max_dataset_length` items from the original dataset to
170
+ # save time consumed by map function
171
+ if max_dataset_length is not None:
172
+ max_dataset_length = min(max_dataset_length, len(dataset))
173
+ indices = np.random.choice(
174
+ len(dataset), max_dataset_length, replace=False)
175
+ dataset = dataset.select(indices)
176
+
177
+ # Extract the useful data for training from the original dataset.
178
+ if dataset_map_fn is not None:
179
+ dataset = map_dataset(dataset, dataset_map_fn, map_num_proc)
180
+
181
+ # Add prompt template, such as <|System|>: xxx <|User|>: xxx <|Bot|>: xxx
182
+ if template_map_fn is not None:
183
+ dataset = add_template_to_dataset(dataset, template_map_fn,
184
+ map_num_proc)
185
+
186
+ for old, new in rename_maps:
187
+ dataset = dataset.rename_column(old, new)
188
+
189
+ # remove unused columns
190
+ if pack_to_max_length and (not remove_unused_columns):
191
+ print_log(
192
+ 'We have to remove unused columns if '
193
+ '`pack_to_max_length` is set to True.',
194
+ logger='current',
195
+ level=logging.WARNING)
196
+ remove_unused_columns = True
197
+
198
+ if do_dataset_tokenization:
199
+ dataset = tokenize_dataset(dataset, tokenizer, max_length,
200
+ with_image_token, input_ids_with_output,
201
+ remove_unused_columns, map_num_proc)
202
+ else:
203
+ assert {'input_ids', 'labels'}.issubset(dataset.column_names)
204
+
205
+ if input_ids_with_output:
206
+ # remove data that does not have the valid labels.
207
+ dataset = dataset.filter(
208
+ lambda example: any(label >= 0 for label in example['labels']),
209
+ num_proc=map_num_proc)
210
+
211
+ # pack to max length
212
+ if pack_to_max_length:
213
+ dataset = pack_dataset(dataset, max_length, use_varlen_attn,
214
+ shuffle_before_pack, map_num_proc)
215
+
216
+ # add 'length'
217
+ dataset = dataset.map(get_lengths, num_proc=map_num_proc)
218
+ setattr(dataset, 'length', dataset['length'])
219
+
220
+ return dataset
221
+
222
+
223
+ def process_hf_dataset(dataset,
224
+ do_dataset_tokenization=True,
225
+ tokenizer=None,
226
+ max_length=None,
227
+ dataset_map_fn=None,
228
+ template_map_fn=None,
229
+ max_dataset_length=None,
230
+ split='train',
231
+ remove_unused_columns=False,
232
+ rename_maps=[],
233
+ shuffle_before_pack=True,
234
+ pack_to_max_length=True,
235
+ use_varlen_attn=False,
236
+ input_ids_with_output=True,
237
+ with_image_token=False,
238
+ map_num_proc=4):
239
+ """Post-process the dataset loaded from the Hugging Face Hub, or a local
240
+ dataset.
241
+
242
+ Args:
243
+ dataset: The dataset to be post-processed.
244
+ do_dataset_tokenization: Whether the dataset need to be tokenized
245
+ in this function. Default to True.
246
+ tokenizer: The tokenizer processes some raw text as input and outputs
247
+ an Encoding. If `do_dataset_tokenization` is True, this argument
248
+ should not be None. Default to None.
249
+ max_length: Max length of the sequence. If `do_dataset_tokenization`
250
+ or `pack_to_max_length` is True, this argument should not be None.
251
+ Default to None.
252
+ dataset_map_fn: Map the original dataset format to the one defined
253
+ by xTuner.
254
+ template_map_fn: Add the prompt template to the dataset
255
+ max_dataset_length: If the length of the dataset is too long, we can
256
+ randomly extract `max_dataset_length` from it.
257
+ split: Which split of the data to load.
258
+ If `None`, will return a single concatenated dataset with all
259
+ splits (typically `datasets.Split.TRAIN` and
260
+ `datasets.Split.TEST`).
261
+ If given, will return a single Dataset.
262
+ remove_unused_columns: Whether to remove columns from the dataset
263
+ that are not used during training.
264
+ rename_maps: Rename the column name of the dataset.
265
+ shuffle_before_pack: Whether to shuffle the dataset before
266
+ packing them.
267
+ pack_to_max_length: Whether to pack the dataset to the `max_length `.
268
+ This usually improves gpu utilization and therefore reduces
269
+ training time.
270
+ use_varlen_attn: If use_varlen_attn is True, we calculate attention
271
+ the actual length of the sequence rather than the actual length
272
+ of the sequence
273
+ input_ids_with_output: Whether to put the groundtruth output
274
+ corresponding to the question into the dataset. Typically set
275
+ it to True during training and False during testing.
276
+ with_image_token: Whether to convert DEFAULT_IMAGE_TOKEN to
277
+ IMAGE_TOKEN_INDEX. Typically set it to True during the training
278
+ of VLM.
279
+ map_num_proc: Max number of processes when mapping the dataset.
280
+ """
281
+ kwargs = dict(
282
+ dataset=dataset,
283
+ do_dataset_tokenization=do_dataset_tokenization,
284
+ tokenizer=tokenizer,
285
+ max_length=max_length,
286
+ dataset_map_fn=dataset_map_fn,
287
+ template_map_fn=template_map_fn,
288
+ max_dataset_length=max_dataset_length,
289
+ split=split,
290
+ remove_unused_columns=remove_unused_columns,
291
+ rename_maps=rename_maps,
292
+ shuffle_before_pack=shuffle_before_pack,
293
+ pack_to_max_length=pack_to_max_length,
294
+ use_varlen_attn=use_varlen_attn,
295
+ input_ids_with_output=input_ids_with_output,
296
+ with_image_token=with_image_token,
297
+ map_num_proc=map_num_proc)
298
+ if not (dist.is_available() and dist.is_initialized()):
299
+ return process(**kwargs)
300
+
301
+ xtuner_dataset_timeout = timedelta(
302
+ minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=30)))
303
+ print_log(
304
+ f'xtuner_dataset_timeout = {xtuner_dataset_timeout}', logger='current')
305
+ # monitored barrier requires gloo process group to perform host-side sync.
306
+ group_gloo = dist.new_group(backend='gloo', timeout=xtuner_dataset_timeout)
307
+
308
+ if dist.get_rank() == 0:
309
+ dataset = process(**kwargs)
310
+ objects = [dataset]
311
+ else:
312
+ objects = [None]
313
+
314
+ dist.monitored_barrier(group=group_gloo, timeout=xtuner_dataset_timeout)
315
+ dist.broadcast_object_list(objects, src=0)
316
+ return objects[0]
modified_xtuner/xtuner/dataset/llava.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import json
3
+ import os
4
+
5
+ import torch
6
+ from datasets import Dataset as HFDataset
7
+ from datasets import DatasetDict
8
+ from mmengine.config import Config, ConfigDict
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+
12
+ from xtuner.registry import BUILDER
13
+ from .huggingface import process_hf_dataset
14
+ from .utils import expand2square
15
+
16
+
17
+ class LLaVADataset(Dataset):
18
+
19
+ def __init__(self,
20
+ data_path,
21
+ image_folder,
22
+ tokenizer,
23
+ image_processor,
24
+ max_dataset_length=None,
25
+ dataset_map_fn=None,
26
+ template_map_fn=None,
27
+ max_length=2048,
28
+ pad_image_to_square=False):
29
+ super().__init__()
30
+
31
+ json_data = json.load(open(data_path))
32
+ for idx in range(len(json_data)):
33
+ if isinstance(json_data[idx]['id'], int):
34
+ json_data[idx]['id'] = str(json_data[idx]['id'])
35
+ json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
36
+ self.text_data = process_hf_dataset(
37
+ dataset=json_data,
38
+ tokenizer=tokenizer,
39
+ max_length=max_length,
40
+ dataset_map_fn=dataset_map_fn,
41
+ template_map_fn=template_map_fn,
42
+ split='train',
43
+ max_dataset_length=max_dataset_length,
44
+ remove_unused_columns=False,
45
+ pack_to_max_length=False,
46
+ with_image_token=True)
47
+
48
+ self.image_folder = image_folder
49
+ if isinstance(image_processor, dict) or isinstance(
50
+ image_processor, Config) or isinstance(image_processor,
51
+ ConfigDict):
52
+ self.image_processor = BUILDER.build(image_processor)
53
+ else:
54
+ self.image_processor = image_processor
55
+ self.pad_image_to_square = pad_image_to_square
56
+
57
+ @property
58
+ def modality_length(self):
59
+ length_list = []
60
+ for data_dict in self.text_data:
61
+ cur_len = len(data_dict['input_ids'])
62
+ if data_dict.get('image', None) is None:
63
+ cur_len = -cur_len
64
+ length_list.append(cur_len)
65
+ return length_list
66
+
67
+ def __len__(self):
68
+ return len(self.text_data)
69
+
70
+ def __getitem__(self, index):
71
+ data_dict = self.text_data[index]
72
+ if data_dict.get('image', None) is not None:
73
+ image_file = data_dict['image']
74
+ image = Image.open(os.path.join(self.image_folder,
75
+ image_file)).convert('RGB')
76
+ if self.pad_image_to_square:
77
+ image = expand2square(
78
+ image,
79
+ tuple(
80
+ int(x * 255) for x in self.image_processor.image_mean))
81
+ image = self.image_processor.preprocess(
82
+ image, return_tensors='pt')['pixel_values'][0]
83
+ data_dict['pixel_values'] = image
84
+ else:
85
+ size = self.image_processor.size
86
+ data_dict['pixel_values'] = torch.zeros(3, size['height'],
87
+ size['width'])
88
+ return data_dict
modified_xtuner/xtuner/tools/chat.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import os
4
+ import os.path as osp
5
+ import re
6
+ import sys
7
+
8
+ import torch
9
+ from huggingface_hub import snapshot_download
10
+ from peft import PeftModel
11
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
12
+ BitsAndBytesConfig, SiglipImageProcessor,
13
+ SiglipVisionModel, GenerationConfig)
14
+ from transformers.generation.streamers import TextStreamer
15
+
16
+ from xtuner.dataset.utils import expand2square, load_image
17
+ from xtuner.model.utils import prepare_inputs_labels_for_multimodal
18
+ from xtuner.tools.utils import get_stop_criteria
19
+ from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
20
+ PROMPT_TEMPLATE, SYSTEM_TEMPLATE)
21
+
22
+ TORCH_DTYPE_MAP = dict(
23
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
24
+
25
+
26
+ def remove_prefix(state_dict, prefix):
27
+ new_state_dict = {}
28
+ for key, value in state_dict.items():
29
+ if key.startswith(prefix):
30
+ new_key = key[len(prefix):]
31
+ new_state_dict[new_key] = value
32
+ else:
33
+ new_state_dict[key] = value
34
+ return new_state_dict
35
+
36
+
37
+ def parse_args():
38
+ parser = argparse.ArgumentParser(description='Chat with a HF model')
39
+ parser.add_argument(
40
+ 'model_name_or_path', help='Hugging Face model name or path')
41
+ adapter_group = parser.add_mutually_exclusive_group()
42
+ adapter_group.add_argument(
43
+ '--adapter', default=None, help='adapter name or path')
44
+ adapter_group.add_argument(
45
+ '--llava', default=None, help='llava name or path')
46
+ parser.add_argument(
47
+ '--visual-encoder', default=None, help='visual encoder name or path')
48
+ parser.add_argument(
49
+ '--visual-select-layer', default=-2, help='visual select layer')
50
+ parser.add_argument('--image', default=None, help='image')
51
+ parser.add_argument(
52
+ '--torch-dtype',
53
+ default='fp16',
54
+ choices=TORCH_DTYPE_MAP.keys(),
55
+ help='Override the default `torch.dtype` and load the model under '
56
+ 'a specific `dtype`.')
57
+ parser.add_argument(
58
+ '--prompt-template',
59
+ choices=PROMPT_TEMPLATE.keys(),
60
+ default=None,
61
+ help='Specify a prompt template')
62
+ system_group = parser.add_mutually_exclusive_group()
63
+ system_group.add_argument(
64
+ '--system', default=None, help='Specify the system text')
65
+ system_group.add_argument(
66
+ '--system-template',
67
+ choices=SYSTEM_TEMPLATE.keys(),
68
+ default=None,
69
+ help='Specify a system template')
70
+ parser.add_argument(
71
+ '--bits',
72
+ type=int,
73
+ choices=[4, 8, None],
74
+ default=None,
75
+ help='LLM bits')
76
+ parser.add_argument(
77
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
78
+ parser.add_argument(
79
+ '--with-plugins',
80
+ nargs='+',
81
+ choices=['calculate', 'solve', 'search'],
82
+ help='Specify plugins to use')
83
+ parser.add_argument(
84
+ '--no-streamer', action='store_true', help='Whether to with streamer')
85
+ parser.add_argument(
86
+ '--lagent', action='store_true', help='Whether to use lagent')
87
+ parser.add_argument(
88
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
89
+ parser.add_argument(
90
+ '--offload-folder',
91
+ default=None,
92
+ help='The folder in which to offload the model weights (or where the '
93
+ 'model weights are already offloaded).')
94
+ parser.add_argument(
95
+ '--max-new-tokens',
96
+ type=int,
97
+ default=2048,
98
+ help='Maximum number of new tokens allowed in generated text')
99
+ parser.add_argument(
100
+ '--temperature',
101
+ type=float,
102
+ default=0.1,
103
+ help='The value used to modulate the next token probabilities.')
104
+ parser.add_argument(
105
+ '--top-k',
106
+ type=int,
107
+ default=40,
108
+ help='The number of highest probability vocabulary tokens to '
109
+ 'keep for top-k-filtering.')
110
+ parser.add_argument(
111
+ '--top-p',
112
+ type=float,
113
+ default=0.75,
114
+ help='If set to float < 1, only the smallest set of most probable '
115
+ 'tokens with probabilities that add up to top_p or higher are '
116
+ 'kept for generation.')
117
+ parser.add_argument(
118
+ '--repetition-penalty',
119
+ type=float,
120
+ default=1.0,
121
+ help='The parameter for repetition penalty. 1.0 means no penalty.')
122
+ parser.add_argument(
123
+ '--seed',
124
+ type=int,
125
+ default=0,
126
+ help='Random seed for reproducible text generation')
127
+ args = parser.parse_args()
128
+ return args
129
+
130
+
131
+ def get_input():
132
+ """Helper function for getting input from users."""
133
+ sentinel = '' # ends when this string is seen
134
+ result = None
135
+ while result is None:
136
+ print(('\ndouble enter to end input (EXIT: exit chat, '
137
+ 'RESET: reset history) >>> '),
138
+ end='')
139
+ try:
140
+ result = '\n'.join(iter(input, sentinel))
141
+ except UnicodeDecodeError:
142
+ print('Invalid characters detected. Please enter again.')
143
+ return result
144
+
145
+
146
+ def main():
147
+ args = parse_args()
148
+ torch.manual_seed(args.seed)
149
+
150
+ # build llm
151
+ quantization_config = None
152
+ load_in_8bit = False
153
+ if args.bits == 4:
154
+ quantization_config = BitsAndBytesConfig(
155
+ load_in_4bit=True,
156
+ load_in_8bit=False,
157
+ llm_int8_threshold=6.0,
158
+ llm_int8_has_fp16_weight=False,
159
+ bnb_4bit_compute_dtype=torch.float16,
160
+ bnb_4bit_use_double_quant=True,
161
+ bnb_4bit_quant_type='nf4')
162
+ elif args.bits == 8:
163
+ load_in_8bit = True
164
+ model_kwargs = {
165
+ 'quantization_config': quantization_config,
166
+ 'load_in_8bit': load_in_8bit,
167
+ 'device_map': 'auto',
168
+ 'offload_folder': args.offload_folder,
169
+ 'trust_remote_code': True,
170
+ 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
171
+ }
172
+ if args.lagent:
173
+ from lagent.actions import ActionExecutor, GoogleSearch
174
+ from lagent.agents import (CALL_PROTOCOL_CN, FORCE_STOP_PROMPT_CN,
175
+ ReAct, ReActProtocol)
176
+ from lagent.llms import HFTransformerCasualLM
177
+
178
+ try:
179
+ SERPER_API_KEY = os.environ['SERPER_API_KEY']
180
+ except Exception:
181
+ print('Please obtain the `SERPER_API_KEY` from https://serper.dev '
182
+ 'and set it using `export SERPER_API_KEY=xxx`.')
183
+ sys.exit(1)
184
+
185
+ model_kwargs.pop('trust_remote_code')
186
+ llm = HFTransformerCasualLM(
187
+ args.model_name_or_path, model_kwargs=model_kwargs)
188
+ if args.adapter is not None:
189
+ print(f'Loading adapter from {args.adapter}...')
190
+ llm.model = PeftModel.from_pretrained(
191
+ llm.model,
192
+ args.adapter,
193
+ offload_folder=args.offload_folder,
194
+ trust_remote_code=True)
195
+ search_tool = GoogleSearch(api_key=SERPER_API_KEY)
196
+ chatbot = ReAct(
197
+ llm=llm,
198
+ action_executor=ActionExecutor(actions=[search_tool]),
199
+ protocol=ReActProtocol(
200
+ call_protocol=CALL_PROTOCOL_CN,
201
+ force_stop=FORCE_STOP_PROMPT_CN))
202
+ while True:
203
+ text = get_input()
204
+ while text.strip() == 'RESET':
205
+ print('Log: History responses have been removed!')
206
+ chatbot._session_history = []
207
+ inputs = ''
208
+ text = get_input()
209
+ if text.strip() == 'EXIT':
210
+ print('Log: Exit!')
211
+ exit(0)
212
+ response = chatbot.chat(text)
213
+ print(response.response)
214
+ else:
215
+ if args.with_plugins is None:
216
+ inner_thoughts_open = False
217
+ calculate_open = False
218
+ solve_open = False
219
+ search_open = False
220
+ else:
221
+ assert args.prompt_template == args.system_template == 'moss_sft'
222
+ from plugins import plugins_api
223
+ inner_thoughts_open = True
224
+ calculate_open = 'calculate' in args.with_plugins
225
+ solve_open = 'solve' in args.with_plugins
226
+ search_open = 'search' in args.with_plugins
227
+ # pre-import for api and model preparation
228
+ if calculate_open:
229
+ from plugins import calculate # noqa: F401
230
+ if solve_open:
231
+ from plugins import solve # noqa: F401
232
+ if search_open:
233
+ from plugins import search # noqa: F401
234
+ # build llm
235
+ llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
236
+ **model_kwargs)
237
+ tokenizer = AutoTokenizer.from_pretrained(
238
+ args.model_name_or_path,
239
+ trust_remote_code=True,
240
+ encode_special_tokens=True)
241
+ print(f'Load LLM from {args.model_name_or_path}')
242
+ if args.adapter is not None:
243
+ llm = PeftModel.from_pretrained(
244
+ llm,
245
+ args.adapter,
246
+ offload_folder=args.offload_folder,
247
+ trust_remote_code=True)
248
+ print(f'Load adapter from {args.adapter}')
249
+ if args.llava is not None:
250
+ llava_path = snapshot_download(
251
+ repo_id=args.llava) if not osp.isdir(
252
+ args.llava) else args.llava
253
+
254
+ # build visual_encoder
255
+ if 'visual_encoder' in os.listdir(llava_path):
256
+ assert args.visual_encoder is None, (
257
+ "Please don't specify the `--visual-encoder` since passed "
258
+ '`--llava` contains a visual encoder!')
259
+ visual_encoder_path = osp.join(llava_path, 'visual_encoder')
260
+ else:
261
+ assert args.visual_encoder is not None, (
262
+ 'Please specify the `--visual-encoder`!')
263
+ visual_encoder_path = args.visual_encoder
264
+ visual_encoder = SiglipVisionModel.from_pretrained(
265
+ visual_encoder_path,
266
+ torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
267
+ image_processor = SiglipImageProcessor.from_pretrained(
268
+ visual_encoder_path)
269
+ print(f'Load visual_encoder from {visual_encoder_path}')
270
+
271
+ # load adapter
272
+ if 'llm_adapter' in os.listdir(llava_path):
273
+ adapter_path = osp.join(llava_path, 'llm_adapter')
274
+ llm = PeftModel.from_pretrained(
275
+ llm,
276
+ adapter_path,
277
+ offload_folder=args.offload_folder,
278
+ trust_remote_code=True)
279
+ print(f'Load LLM adapter from {args.llava}')
280
+ if 'visual_encoder_adapter' in os.listdir(llava_path):
281
+ adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
282
+ visual_encoder = PeftModel.from_pretrained(
283
+ visual_encoder,
284
+ adapter_path,
285
+ offload_folder=args.offload_folder)
286
+ print(f'Load visual_encoder adapter from {args.llava}')
287
+
288
+ # build projector
289
+ projector_path = osp.join(llava_path, 'projector')
290
+ projector = AutoModel.from_pretrained(
291
+ projector_path,
292
+ torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype],
293
+ trust_remote_code=True)
294
+ print(f'Load projector from {args.llava}')
295
+
296
+ projector.cuda()
297
+ projector.eval()
298
+ visual_encoder.cuda()
299
+ visual_encoder.eval()
300
+
301
+ llm.eval()
302
+
303
+ if args.image is not None:
304
+ image = load_image(args.image)
305
+ image = expand2square(
306
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
307
+ image = image_processor.preprocess(
308
+ image, return_tensors='pt')['pixel_values'][0]
309
+ image = image.cuda().unsqueeze(0)
310
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
311
+ pixel_values = projector(
312
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
313
+
314
+ stop_words = args.stop_words
315
+ sep = ''
316
+ if args.prompt_template:
317
+ template = PROMPT_TEMPLATE[args.prompt_template]
318
+ stop_words += template.get('STOP_WORDS', [])
319
+ sep = template.get('SEP', '')
320
+ stop_criteria = get_stop_criteria(
321
+ tokenizer=tokenizer, stop_words=stop_words)
322
+
323
+ if args.no_streamer:
324
+ streamer = None
325
+ else:
326
+ streamer = TextStreamer(tokenizer, skip_prompt=True)
327
+
328
+ gen_config = GenerationConfig(
329
+ max_new_tokens=args.max_new_tokens,
330
+ do_sample=args.temperature > 0,
331
+ temperature=args.temperature,
332
+ top_p=args.top_p,
333
+ top_k=args.top_k,
334
+ repetition_penalty=args.repetition_penalty,
335
+ eos_token_id=tokenizer.eos_token_id,
336
+ pad_token_id=tokenizer.pad_token_id
337
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
338
+ )
339
+
340
+ n_turn = 0
341
+ inputs = ''
342
+ while True:
343
+ text = get_input()
344
+ while text.strip() == 'RESET':
345
+ print('Log: History responses have been removed!')
346
+ n_turn = 0
347
+ inputs = ''
348
+ text = get_input()
349
+ if text.strip() == 'EXIT':
350
+ print('Log: Exit!')
351
+ exit(0)
352
+
353
+ if args.image is not None and n_turn == 0:
354
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
355
+
356
+ if args.prompt_template:
357
+ prompt_text = ''
358
+ template = PROMPT_TEMPLATE[args.prompt_template]
359
+ if 'SYSTEM' in template and n_turn == 0:
360
+ system_text = None
361
+ if args.system_template is not None:
362
+ system_text = SYSTEM_TEMPLATE[
363
+ args.system_template].format(
364
+ round=n_turn + 1, bot_name=args.bot_name)
365
+ elif args.system is not None:
366
+ system_text = args.system
367
+ if system_text is not None:
368
+ prompt_text += template['SYSTEM'].format(
369
+ system=system_text,
370
+ round=n_turn + 1,
371
+ bot_name=args.bot_name)
372
+ prompt_text += template['INSTRUCTION'].format(
373
+ input=text, round=n_turn + 1, bot_name=args.bot_name)
374
+ if args.prompt_template == args.system_template == 'moss_sft':
375
+ if not inner_thoughts_open:
376
+ prompt_text.replace('- Inner thoughts: enabled.',
377
+ '- Inner thoughts: disabled.')
378
+ if not calculate_open:
379
+ prompt_text.replace(('- Calculator: enabled. API: '
380
+ 'Calculate(expression)'),
381
+ '- Calculator: disabled.')
382
+ if not solve_open:
383
+ prompt_text.replace(
384
+ '- Equation solver: enabled. API: Solve(equation)',
385
+ '- Equation solver: disabled.')
386
+ if not search_open:
387
+ prompt_text.replace(
388
+ '- Web search: enabled. API: Search(query)',
389
+ '- Web search: disabled.')
390
+ else:
391
+ prompt_text = text
392
+ inputs += prompt_text
393
+ if args.image is None:
394
+ if n_turn == 0:
395
+ ids = tokenizer.encode(inputs, return_tensors='pt')
396
+ else:
397
+ ids = tokenizer.encode(
398
+ inputs, return_tensors='pt', add_special_tokens=False)
399
+
400
+ if args.with_plugins is not None:
401
+ generate_output = llm.generate(
402
+ inputs=ids.cuda(),
403
+ generation_config=gen_config,
404
+ streamer=streamer,
405
+ stopping_criteria=stop_criteria).cpu()
406
+ generate_output_text = tokenizer.decode(
407
+ generate_output[0][len(ids[0]):])
408
+ if streamer is None:
409
+ end = '' if generate_output_text[-1] == '\n' else '\n'
410
+ print(generate_output_text, end=end)
411
+ pattern = r'<\|Commands\|>:(.*?)<eoc>'
412
+ command_text = ', '.join(
413
+ re.findall(pattern, generate_output_text))
414
+ extent_text = plugins_api(
415
+ command_text,
416
+ calculate_open=calculate_open,
417
+ solve_open=solve_open,
418
+ search_open=search_open)
419
+ end = '' if extent_text[-1] == '\n' else '\n'
420
+ print(extent_text, end=end)
421
+ extent_text_ids = tokenizer.encode(
422
+ extent_text,
423
+ return_tensors='pt',
424
+ add_special_tokens=False)
425
+ new_ids = torch.cat((generate_output, extent_text_ids),
426
+ dim=1)
427
+
428
+ generate_output = llm.generate(
429
+ inputs=new_ids.cuda(),
430
+ generation_config=gen_config,
431
+ streamer=streamer,
432
+ stopping_criteria=stop_criteria)
433
+ if streamer is None:
434
+ output_text = tokenizer.decode(
435
+ generate_output[0][len(new_ids[0]):])
436
+ end = '' if output_text[-1] == '\n' else '\n'
437
+ print(output_text, end=end)
438
+ else:
439
+ generate_output = llm.generate(
440
+ inputs=ids.cuda(),
441
+ generation_config=gen_config,
442
+ streamer=streamer,
443
+ stopping_criteria=stop_criteria)
444
+ if streamer is None:
445
+ output_text = tokenizer.decode(
446
+ generate_output[0][len(ids[0]):])
447
+ end = '' if output_text[-1] == '\n' else '\n'
448
+ print(output_text, end=end)
449
+ inputs = tokenizer.decode(generate_output[0])
450
+ else:
451
+ chunk_encode = []
452
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
453
+ if idx == 0 and n_turn == 0:
454
+ cur_encode = tokenizer.encode(chunk)
455
+ else:
456
+ cur_encode = tokenizer.encode(
457
+ chunk, add_special_tokens=False)
458
+ chunk_encode.append(cur_encode)
459
+ assert len(chunk_encode) == 2
460
+ ids = []
461
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
462
+ ids.extend(cur_chunk_encode)
463
+ if idx != len(chunk_encode) - 1:
464
+ ids.append(IMAGE_TOKEN_INDEX)
465
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
466
+ mm_inputs = prepare_inputs_labels_for_multimodal(
467
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
468
+
469
+ generate_output = llm.generate(
470
+ **mm_inputs,
471
+ generation_config=gen_config,
472
+ streamer=streamer,
473
+ bos_token_id=tokenizer.bos_token_id,
474
+ stopping_criteria=stop_criteria)
475
+ if streamer is None:
476
+ output_text = tokenizer.decode(generate_output[0])
477
+ end = '' if output_text[-1] == '\n' else '\n'
478
+ print(output_text, end=end)
479
+ inputs += tokenizer.decode(generate_output[0])
480
+ n_turn += 1
481
+ inputs += sep
482
+ if len(generate_output[0]) >= args.max_new_tokens:
483
+ print(
484
+ 'Remove the memory of history responses, since '
485
+ f'it exceeds the length limitation {args.max_new_tokens}.')
486
+ n_turn = 0
487
+ inputs = ''
488
+
489
+
490
+ if __name__ == '__main__':
491
+ main()
modified_xtuner/xtuner/tools/mmbench.py ADDED
@@ -0,0 +1,510 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import argparse
3
+ import json
4
+ import math
5
+ import os
6
+ import os.path as osp
7
+ import re
8
+ import string
9
+ import time
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import tqdm
15
+ from huggingface_hub import snapshot_download
16
+ from mmengine import mkdir_or_exist
17
+ from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
18
+ master_only)
19
+ from mmengine.utils.dl_utils import set_multi_processing
20
+ from peft import PeftModel
21
+ from rich.console import Console
22
+ from rich.table import Table
23
+ from torch.utils.data import Dataset
24
+ from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
25
+ BitsAndBytesConfig, SiglipImageProcessor,
26
+ SiglipVisionModel, GenerationConfig)
27
+
28
+ from xtuner.dataset.utils import decode_base64_to_image, expand2square
29
+ from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
30
+ from xtuner.tools.utils import get_stop_criteria, is_cn_string
31
+ from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
32
+ PROMPT_TEMPLATE)
33
+
34
+ TORCH_DTYPE_MAP = dict(
35
+ fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
36
+
37
+
38
+ def parse_args():
39
+ parser = argparse.ArgumentParser(description='MMBench')
40
+ parser.add_argument(
41
+ 'model_name_or_path', help='Hugging Face model name or path')
42
+ parser.add_argument('--data-path', default=None, help='data path')
43
+ parser.add_argument('--work-dir', help='the dir to save results')
44
+ parser.add_argument('--llava', default=None, help='llava name or path')
45
+ parser.add_argument(
46
+ '--visual-encoder', default=None, help='visual encoder name or path')
47
+ parser.add_argument(
48
+ '--visual-select-layer', default=-2, help='visual select layer')
49
+ parser.add_argument(
50
+ '--prompt-template',
51
+ choices=PROMPT_TEMPLATE.keys(),
52
+ default=None,
53
+ help='Specify a prompt template')
54
+ parser.add_argument(
55
+ '--stop-words', nargs='+', type=str, default=[], help='Stop words')
56
+ parser.add_argument(
57
+ '--torch-dtype',
58
+ default='fp16',
59
+ choices=TORCH_DTYPE_MAP.keys(),
60
+ help='Override the default `torch.dtype` and load the model under '
61
+ 'a specific `dtype`.')
62
+ parser.add_argument(
63
+ '--bits',
64
+ type=int,
65
+ choices=[4, 8, None],
66
+ default=None,
67
+ help='LLM bits')
68
+ parser.add_argument(
69
+ '--bot-name', type=str, default='BOT', help='Name for Bot')
70
+ parser.add_argument(
71
+ '--offload-folder',
72
+ default=None,
73
+ help='The folder in which to offload the model weights (or where the '
74
+ 'model weights are already offloaded).')
75
+ parser.add_argument(
76
+ '--max-new-tokens',
77
+ type=int,
78
+ default=100,
79
+ help='Maximum number of new tokens allowed in generated text')
80
+ parser.add_argument(
81
+ '--seed',
82
+ type=int,
83
+ default=0,
84
+ help='Random seed for reproducible text generation')
85
+ parser.add_argument(
86
+ '--launcher',
87
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
88
+ default='none',
89
+ help='job launcher')
90
+ args = parser.parse_args()
91
+ return args
92
+
93
+
94
+ @master_only
95
+ def master_print(msg):
96
+ print(msg)
97
+
98
+
99
+ class MMBenchDataset(Dataset):
100
+ ABBRS = {
101
+ 'coarse_perception': 'CP',
102
+ 'finegrained_perception (instance-level)': 'FP-S',
103
+ 'finegrained_perception (cross-instance)': 'FP-C',
104
+ 'logic_reasoning': 'LR',
105
+ 'relation_reasoning': 'RR',
106
+ 'attribute_reasoning': 'AR',
107
+ 'sketch_reasoning': 'Sketch Reasoning',
108
+ 'scenery_building': 'Scenery & Building',
109
+ 'food_clothes': 'Food & Clothes',
110
+ 'historical_figure': 'Historical Figure',
111
+ 'traditional_show': 'Traditional Show',
112
+ 'calligraphy_painting': 'Calligraphy Painting',
113
+ 'cultural_relic': 'Cultural Relic'
114
+ }
115
+
116
+ def __init__(self, data_file):
117
+ self.data_file = data_file
118
+ self.df = pd.read_csv(data_file, sep='\t')
119
+ self.split = 'dev' if 'answer' in self.df.iloc[0].keys() else 'test'
120
+ self.has_l2_category = 'l2-category' in self.df.columns.to_list()
121
+
122
+ def get_image(self, image):
123
+ while len(image) < 16:
124
+ image = self.df[self.df['index'] == int(image)]['image'].values
125
+ assert len(image) == 1
126
+ image = image[0]
127
+ image = decode_base64_to_image(image)
128
+ return image
129
+
130
+ def __len__(self):
131
+ return len(self.df)
132
+
133
+ def __getitem__(self, idx):
134
+ index = self.df.iloc[idx]['index']
135
+ image = self.df.iloc[idx]['image']
136
+ image = self.get_image(image)
137
+ question = self.df.iloc[idx]['question']
138
+ answer = self.df.iloc[idx]['answer'] if 'answer' in self.df.iloc[
139
+ 0].keys() else None
140
+ category = self.df.iloc[idx]['category']
141
+
142
+ options = {
143
+ cand: self.load_from_df(idx, cand)
144
+ for cand in string.ascii_uppercase
145
+ if self.load_from_df(idx, cand) is not None
146
+ }
147
+ options_prompt = ''
148
+ for key, item in options.items():
149
+ options_prompt += f'{key}. {item}\n'
150
+
151
+ hint = self.load_from_df(idx, 'hint')
152
+ data = {
153
+ 'img': image,
154
+ 'question': question,
155
+ 'answer': answer,
156
+ 'options': options_prompt,
157
+ 'category': category,
158
+ 'options_dict': options,
159
+ 'index': index,
160
+ 'context': hint,
161
+ }
162
+ if self.has_l2_category:
163
+ data.update({'l2-category': self.df.iloc[idx]['l2-category']})
164
+ return data
165
+
166
+ def load_from_df(self, idx, key):
167
+ if key in self.df.iloc[idx] and not pd.isna(self.df.iloc[idx][key]):
168
+ return self.df.iloc[idx][key]
169
+ else:
170
+ return None
171
+
172
+ @master_only
173
+ def eval_result(self, result_df, show=True):
174
+
175
+ def calc_acc(df, group='category'):
176
+ assert group in ['overall', 'category', 'l2-category']
177
+ if group == 'overall':
178
+ res = {'Average': np.mean(df['hit'])}
179
+ else:
180
+ res = {}
181
+ abilities = list(set(df[group]))
182
+ abilities.sort()
183
+ for ab in abilities:
184
+ sub_df = df[df[group] == ab]
185
+ ab = self.ABBRS[ab] if ab in self.ABBRS else ab
186
+ res[ab] = np.mean(sub_df['hit'])
187
+ return res
188
+
189
+ def eval_sub_data(sub_data, answer_map):
190
+ lt = len(sub_data)
191
+ for i in range(lt):
192
+ item = sub_data.iloc[i]
193
+ match = re.search(r'([A-D]+)', item['prediction'])
194
+ pred = match.group(1) if match else ''
195
+ gt = answer_map[item['index']]
196
+ if gt != pred:
197
+ return 0
198
+ return 1
199
+
200
+ def show_result(ret_json):
201
+ show_dict = ret_json.copy()
202
+ table = Table(title=f' MMBench ({self.data_file}) ')
203
+ console = Console()
204
+ table.add_column('Category', justify='left')
205
+ table.add_column('Accuracy (%)', justify='right')
206
+ average = show_dict.pop('Average') * 100
207
+ table.add_row('Average', f'{average:.1f}')
208
+ table.add_section()
209
+ for cat_name, cat_acc in show_dict.items():
210
+ table.add_row(cat_name, f'{cat_acc * 100:.1f}')
211
+ with console.capture() as capture:
212
+ console.print(table, end='')
213
+ print('\n' + capture.get())
214
+ print('Note: Please be cautious if you use the results in papers, '
215
+ "since we don't use ChatGPT as a helper for choice "
216
+ 'extraction')
217
+
218
+ data = result_df.sort_values(by='index')
219
+ data['prediction'] = [str(x) for x in data['prediction']]
220
+ for k in data.keys():
221
+ data[k.lower() if k not in 'ABCD' else k] = data.pop(k)
222
+
223
+ data_main = data[data['index'] < int(1e6)]
224
+ cate_map = {
225
+ i: c
226
+ for i, c in zip(self.df['index'], self.df['category'])
227
+ }
228
+ if self.has_l2_category:
229
+ l2_cate_map = {
230
+ i: c
231
+ for i, c in zip(self.df['index'], self.df['l2-category'])
232
+ }
233
+ answer_map = {
234
+ i: c
235
+ for i, c in zip(self.df['index'], self.df['answer'])
236
+ }
237
+
238
+ lt = len(data_main)
239
+ hit, tot = 0, 0
240
+ result = {}
241
+ for i in range(lt):
242
+ item_main = data_main.iloc[i]
243
+ idx = item_main['index']
244
+ assert idx not in result
245
+ sub_data = data[data['index'] % int(1e6) == idx]
246
+ ret = eval_sub_data(sub_data, answer_map)
247
+ result[idx] = ret
248
+ hit += ret
249
+ tot += 1
250
+
251
+ indices = data_main['index']
252
+ data_main = data_main.copy()
253
+ data_main['hit'] = [result[i] for i in indices]
254
+ main_idx = data_main['index']
255
+ data_main['category'] = [cate_map[i] for i in main_idx]
256
+
257
+ ret_json = calc_acc(data_main, 'overall')
258
+
259
+ if self.has_l2_category:
260
+ data_main['l2-category'] = [l2_cate_map[i] for i in main_idx]
261
+ l2 = calc_acc(data_main, 'l2-category')
262
+ ret_json.update(l2)
263
+ else:
264
+ leaf = calc_acc(data_main, 'category')
265
+ ret_json.update(leaf)
266
+ if show:
267
+ show_result(ret_json)
268
+ return ret_json
269
+
270
+
271
+ def main():
272
+ args = parse_args()
273
+
274
+ torch.manual_seed(args.seed)
275
+
276
+ if args.launcher != 'none':
277
+ set_multi_processing(distributed=True)
278
+ init_dist(args.launcher)
279
+
280
+ rank, world_size = get_dist_info()
281
+ torch.cuda.set_device(rank)
282
+ else:
283
+ rank = 0
284
+ world_size = 1
285
+
286
+ # build llm
287
+ quantization_config = None
288
+ load_in_8bit = False
289
+ if args.bits == 4:
290
+ quantization_config = BitsAndBytesConfig(
291
+ load_in_4bit=True,
292
+ load_in_8bit=False,
293
+ llm_int8_threshold=6.0,
294
+ llm_int8_has_fp16_weight=False,
295
+ bnb_4bit_compute_dtype=torch.float16,
296
+ bnb_4bit_use_double_quant=True,
297
+ bnb_4bit_quant_type='nf4')
298
+ elif args.bits == 8:
299
+ load_in_8bit = True
300
+ model_kwargs = {
301
+ 'quantization_config': quantization_config,
302
+ 'load_in_8bit': load_in_8bit,
303
+ 'device_map': rank if world_size > 1 else 'auto',
304
+ 'offload_folder': args.offload_folder,
305
+ 'trust_remote_code': True,
306
+ 'torch_dtype': TORCH_DTYPE_MAP[args.torch_dtype]
307
+ }
308
+
309
+ # build llm
310
+ with LoadWoInit():
311
+ llm = AutoModelForCausalLM.from_pretrained(args.model_name_or_path,
312
+ **model_kwargs)
313
+ tokenizer = AutoTokenizer.from_pretrained(
314
+ args.model_name_or_path,
315
+ trust_remote_code=True,
316
+ encode_special_tokens=True)
317
+ master_print(f'Load LLM from {args.model_name_or_path}')
318
+
319
+ llava_path = snapshot_download(
320
+ repo_id=args.llava) if not osp.isdir(args.llava) else args.llava
321
+
322
+ # build visual_encoder
323
+ if 'visual_encoder' in os.listdir(llava_path):
324
+ assert args.visual_encoder is None, (
325
+ "Please don't specify the `--visual-encoder` since passed "
326
+ '`--llava` contains a visual encoder!')
327
+ visual_encoder_path = osp.join(llava_path, 'visual_encoder')
328
+ else:
329
+ assert args.visual_encoder is not None, (
330
+ 'Please specify the `--visual-encoder`!')
331
+ visual_encoder_path = args.visual_encoder
332
+ with LoadWoInit():
333
+ visual_encoder = SiglipVisionModel.from_pretrained(
334
+ visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
335
+ image_processor = SiglipImageProcessor.from_pretrained(
336
+ visual_encoder_path)
337
+ master_print(f'Load visual_encoder from {visual_encoder_path}')
338
+
339
+ # load adapter
340
+ if 'llm_adapter' in os.listdir(llava_path):
341
+ adapter_path = osp.join(llava_path, 'llm_adapter')
342
+
343
+ with LoadWoInit():
344
+ llm = PeftModel.from_pretrained(
345
+ llm, adapter_path, offload_folder=args.offload_folder)
346
+
347
+ master_print(f'Load LLM adapter from {args.llava}')
348
+
349
+ if 'visual_encoder_adapter' in os.listdir(llava_path):
350
+ adapter_path = osp.join(llava_path, 'visual_encoder_adapter')
351
+ visual_encoder = PeftModel.from_pretrained(
352
+ visual_encoder, adapter_path, offload_folder=args.offload_folder)
353
+ master_print(f'Load visual_encoder adapter from {args.llava}')
354
+
355
+ # build projector
356
+ projector_path = osp.join(llava_path, 'projector')
357
+ with LoadWoInit():
358
+ projector = AutoModel.from_pretrained(
359
+ projector_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
360
+ master_print(f'Load projector from {args.llava}')
361
+
362
+ projector.cuda()
363
+ projector.eval()
364
+
365
+ visual_encoder.cuda()
366
+ visual_encoder.eval()
367
+
368
+ llm.eval()
369
+
370
+ stop_words = args.stop_words
371
+ if args.prompt_template:
372
+ template = PROMPT_TEMPLATE[args.prompt_template]
373
+ stop_words += template.get('STOP_WORDS', [])
374
+ stop_criteria = get_stop_criteria(
375
+ tokenizer=tokenizer, stop_words=stop_words)
376
+
377
+ gen_config = GenerationConfig(
378
+ max_new_tokens=args.max_new_tokens,
379
+ do_sample=False,
380
+ eos_token_id=tokenizer.eos_token_id,
381
+ pad_token_id=tokenizer.pad_token_id
382
+ if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
383
+ )
384
+
385
+ # work_dir
386
+ if args.work_dir is not None:
387
+ # update configs according to CLI args if args.work_dir is not None
388
+ save_dir = args.work_dir
389
+ else:
390
+ # use config filename as default work_dir
391
+ save_dir = osp.join('./work_dirs',
392
+ osp.splitext(osp.basename(args.data_path))[0])
393
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime(time.time()))
394
+ save_dir = osp.join(save_dir, timestamp)
395
+
396
+ if rank == 0:
397
+ mkdir_or_exist(osp.abspath(save_dir))
398
+ print('=======================================================')
399
+ print(f'Dataset path: {osp.abspath(args.data_path)}\n'
400
+ f'Results will be saved to {osp.abspath(save_dir)}')
401
+ print('=======================================================')
402
+
403
+ args_path = osp.join(save_dir, 'args.json')
404
+ with open(args_path, 'w') as f:
405
+ json.dump(args.__dict__, f, indent=2)
406
+
407
+ results_xlsx_path = osp.join(save_dir, 'mmbench_result.xlsx')
408
+ results_json_path = osp.join(save_dir, 'mmbench_result.json')
409
+
410
+ dataset = MMBenchDataset(args.data_path)
411
+
412
+ results = []
413
+ n_samples = len(dataset)
414
+ per_rank_samples = math.ceil(n_samples / world_size)
415
+
416
+ per_rank_ids = range(per_rank_samples * rank,
417
+ min(n_samples, per_rank_samples * (rank + 1)))
418
+ for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
419
+ data_sample = dataset[i]
420
+ if data_sample['context'] is not None:
421
+ text = data_sample['context'] + '\n' + data_sample[
422
+ 'question'] + '\n' + data_sample['options']
423
+ else:
424
+ text = data_sample['question'] + '\n' + data_sample['options']
425
+
426
+ text = DEFAULT_IMAGE_TOKEN + '\n' + text
427
+
428
+ if is_cn_string(text):
429
+ text = text + '请直接回答选项字母。'
430
+ else:
431
+ text = text + ("Answer with the option's letter from the "
432
+ 'given choices directly.')
433
+
434
+ if args.prompt_template:
435
+ prompt_text = ''
436
+ template = PROMPT_TEMPLATE[args.prompt_template]
437
+ prompt_text += template['INSTRUCTION'].format(
438
+ input=text, round=1, bot_name=args.bot_name)
439
+ else:
440
+ prompt_text = text
441
+ inputs = prompt_text
442
+
443
+ image = data_sample['img'].convert('RGB')
444
+ image = expand2square(
445
+ image, tuple(int(x * 255) for x in image_processor.image_mean))
446
+ image = image_processor.preprocess(
447
+ image, return_tensors='pt')['pixel_values'][0]
448
+ image = image.cuda().unsqueeze(0)
449
+ visual_outputs = visual_encoder(image, output_hidden_states=True)
450
+ pixel_values = projector(
451
+ visual_outputs.hidden_states[args.visual_select_layer][:, 1:])
452
+
453
+ chunk_encode = []
454
+ for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
455
+ if idx == 0:
456
+ cur_encode = tokenizer.encode(chunk)
457
+ else:
458
+ cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
459
+ chunk_encode.append(cur_encode)
460
+ assert len(chunk_encode) == 2
461
+ ids = []
462
+ for idx, cur_chunk_encode in enumerate(chunk_encode):
463
+ ids.extend(cur_chunk_encode)
464
+ if idx != len(chunk_encode) - 1:
465
+ ids.append(IMAGE_TOKEN_INDEX)
466
+ ids = torch.tensor(ids).cuda().unsqueeze(0)
467
+ mm_inputs = prepare_inputs_labels_for_multimodal(
468
+ llm=llm, input_ids=ids, pixel_values=pixel_values)
469
+
470
+ generate_output = llm.generate(
471
+ **mm_inputs,
472
+ generation_config=gen_config,
473
+ streamer=None,
474
+ bos_token_id=tokenizer.bos_token_id,
475
+ stopping_criteria=stop_criteria)
476
+
477
+ predict = tokenizer.decode(
478
+ generate_output[0], skip_special_tokens=True).strip()
479
+ cur_result = {}
480
+ cur_result['question'] = data_sample.get('question')
481
+ cur_result.update(data_sample.get('options_dict'))
482
+ cur_result['prediction'] = predict
483
+ if data_sample.get('category') is not None:
484
+ cur_result['category'] = data_sample.get('category')
485
+ if data_sample.get('l2-category') is not None:
486
+ cur_result['l2-category'] = data_sample.get('l2-category')
487
+ cur_result['index'] = data_sample.get('index')
488
+ cur_result['split'] = data_sample.get('split')
489
+ cur_result['answer'] = data_sample.get('answer')
490
+ results.append(cur_result)
491
+
492
+ results = collect_results(results, n_samples)
493
+
494
+ if get_rank() == 0:
495
+
496
+ results_df = pd.DataFrame(results)
497
+ with pd.ExcelWriter(results_xlsx_path, engine='openpyxl') as writer:
498
+ results_df.to_excel(writer, index=False)
499
+
500
+ if dataset.split == 'dev':
501
+ results_dict = dataset.eval_result(results_df, show=True)
502
+ with open(results_json_path, 'w') as f:
503
+ json.dump(results_dict, f, indent=2)
504
+ else:
505
+ print('All done!')
506
+
507
+
508
+ if __name__ == '__main__':
509
+
510
+ main()
pretrain.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from mmengine.dataset import DefaultSampler
4
+ from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5
+ LoggerHook, ParamSchedulerHook)
6
+ from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7
+ from torch.optim import AdamW
8
+ from transformers import (AutoModelForCausalLM, AutoTokenizer,
9
+ BitsAndBytesConfig, SiglipImageProcessor,
10
+ SiglipVisionModel)
11
+
12
+ from xtuner.dataset import LLaVADataset
13
+ from xtuner.dataset.collate_fns import default_collate_fn
14
+ from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory
15
+ from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
16
+ from xtuner.engine.runner import TrainLoop
17
+ from xtuner.model import LLaVAModel
18
+ from xtuner.utils import PROMPT_TEMPLATE
19
+
20
+ #######################################################################
21
+ # PART 1 Settings #
22
+ #######################################################################
23
+ # Model
24
+ llm_name_or_path = 'internlm/internlm2-1_8b'
25
+ visual_encoder_name_or_path = 'google/siglip-so400m-patch14-384'
26
+
27
+ # Data
28
+ data_root = './'
29
+ data_path = data_root + 'LLaVA-Pretrain/blip_laion_cc_sbu_558k.json'
30
+ image_folder = data_root + 'LLaVA-Pretrain/images'
31
+ prompt_template = PROMPT_TEMPLATE.internlm2_chat
32
+ max_length = int(2048 - (336 / 14)**2)
33
+
34
+ # Scheduler & Optimizer
35
+ batch_size = 12 # per_device
36
+ accumulative_counts = 5
37
+ dataloader_num_workers = 12
38
+ max_epochs = 1
39
+ optim_type = AdamW
40
+ lr = 1e-3
41
+ betas = (0.9, 0.999)
42
+ weight_decay = 0
43
+ max_norm = 1 # grad clip
44
+ warmup_ratio = 0.03
45
+
46
+ # Save
47
+ save_steps = 2000
48
+ save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
49
+
50
+ # Evaluate the generation performance during the training
51
+ evaluation_freq = 2000
52
+ SYSTEM = ''
53
+ evaluation_images = 'https://llava-vl.github.io/static/images/view.jpg'
54
+ evaluation_inputs = ['请描述一下这张照片', 'Please describe this picture']
55
+
56
+ #######################################################################
57
+ # PART 2 Model & Tokenizer & Image Processor #
58
+ #######################################################################
59
+ tokenizer = dict(
60
+ type=AutoTokenizer.from_pretrained,
61
+ pretrained_model_name_or_path=llm_name_or_path,
62
+ trust_remote_code=True,
63
+ padding_side='right')
64
+
65
+ image_processor = dict(
66
+ type=SiglipImageProcessor.from_pretrained,
67
+ pretrained_model_name_or_path=visual_encoder_name_or_path,
68
+ trust_remote_code=True)
69
+
70
+ model = dict(
71
+ type=LLaVAModel,
72
+ freeze_llm=True,
73
+ freeze_visual_encoder=True,
74
+ llm=dict(
75
+ type=AutoModelForCausalLM.from_pretrained,
76
+ pretrained_model_name_or_path=llm_name_or_path,
77
+ trust_remote_code=True,
78
+ torch_dtype=torch.float16,
79
+ quantization_config=dict(
80
+ type=BitsAndBytesConfig,
81
+ load_in_4bit=True,
82
+ load_in_8bit=False,
83
+ llm_int8_threshold=6.0,
84
+ llm_int8_has_fp16_weight=False,
85
+ bnb_4bit_compute_dtype=torch.float16,
86
+ bnb_4bit_use_double_quant=True,
87
+ bnb_4bit_quant_type='nf4')),
88
+ visual_encoder=dict(
89
+ type=SiglipVisionModel.from_pretrained,
90
+ pretrained_model_name_or_path=visual_encoder_name_or_path))
91
+
92
+ #######################################################################
93
+ # PART 3 Dataset & Dataloader #
94
+ #######################################################################
95
+ llava_dataset = dict(
96
+ type=LLaVADataset,
97
+ data_path=data_path,
98
+ image_folder=image_folder,
99
+ tokenizer=tokenizer,
100
+ image_processor=image_processor,
101
+ dataset_map_fn=llava_map_fn,
102
+ template_map_fn=dict(
103
+ type=template_map_fn_factory, template=prompt_template),
104
+ max_length=max_length,
105
+ pad_image_to_square=False)
106
+
107
+ train_dataloader = dict(
108
+ batch_size=batch_size,
109
+ num_workers=dataloader_num_workers,
110
+ dataset=llava_dataset,
111
+ sampler=dict(type=DefaultSampler, shuffle=True),
112
+ collate_fn=dict(type=default_collate_fn))
113
+
114
+ #######################################################################
115
+ # PART 4 Scheduler & Optimizer #
116
+ #######################################################################
117
+ # optimizer
118
+ optim_wrapper = dict(
119
+ type=AmpOptimWrapper,
120
+ optimizer=dict(
121
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
122
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
123
+ accumulative_counts=accumulative_counts,
124
+ loss_scale='dynamic',
125
+ dtype='float16')
126
+
127
+ # learning policy
128
+ # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
129
+ param_scheduler = [
130
+ dict(
131
+ type=LinearLR,
132
+ start_factor=1e-5,
133
+ by_epoch=True,
134
+ begin=0,
135
+ end=warmup_ratio * max_epochs,
136
+ convert_to_iter_based=True),
137
+ dict(
138
+ type=CosineAnnealingLR,
139
+ eta_min=0.0,
140
+ by_epoch=True,
141
+ begin=warmup_ratio * max_epochs,
142
+ end=max_epochs,
143
+ convert_to_iter_based=True)
144
+ ]
145
+
146
+ # train, val, test setting
147
+ train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
148
+
149
+ #######################################################################
150
+ # PART 5 Runtime #
151
+ #######################################################################
152
+ # Log the dialogue periodically during the training process, optional
153
+ custom_hooks = [
154
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
155
+ dict(
156
+ type=EvaluateChatHook,
157
+ tokenizer=tokenizer,
158
+ image_processor=image_processor,
159
+ every_n_iters=evaluation_freq,
160
+ evaluation_inputs=evaluation_inputs,
161
+ evaluation_images=evaluation_images,
162
+ system=SYSTEM,
163
+ prompt_template=prompt_template)
164
+ ]
165
+
166
+ # configure default hooks
167
+ default_hooks = dict(
168
+ # record the time of every iteration.
169
+ timer=dict(type=IterTimerHook),
170
+ # print log every 10 iterations.
171
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
172
+ # enable the parameter scheduler.
173
+ param_scheduler=dict(type=ParamSchedulerHook),
174
+ # save checkpoint per `save_steps`.
175
+ checkpoint=dict(
176
+ type=CheckpointHook,
177
+ by_epoch=False,
178
+ interval=save_steps,
179
+ max_keep_ckpts=save_total_limit),
180
+ # set sampler seed in distributed evrionment.
181
+ sampler_seed=dict(type=DistSamplerSeedHook),
182
+ )
183
+
184
+ # configure environment
185
+ env_cfg = dict(
186
+ # whether to enable cudnn benchmark
187
+ cudnn_benchmark=False,
188
+ # set multi process parameters
189
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
190
+ # set distributed parameters
191
+ dist_cfg=dict(backend='nccl'),
192
+ )
193
+
194
+ # set visualizer
195
+ from mmengine.visualization import Visualizer, TensorboardVisBackend
196
+ visualizer = dict(
197
+ type=Visualizer,
198
+ vis_backends=[dict(type=TensorboardVisBackend)]
199
+ )
200
+
201
+ # set log level
202
+ log_level = 'INFO'
203
+
204
+ # load from which checkpoint
205
+ load_from = None
206
+
207
+ # whether to resume training from the loaded checkpoint
208
+ resume = False
209
+
210
+ # Defaults to use random seed and disable `deterministic`
211
+ randomness = dict(seed=None, deterministic=False)
212
+
213
+ # set log processor
214
+ log_processor = dict(by_epoch=False)