Fabrice-TIERCELIN commited on
Commit
ad4972e
·
verified ·
1 Parent(s): f76424f

Upload 7 files

Browse files
llava/model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2
+ from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
llava/model/apply_delta.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava import LlavaLlamaForCausalLM
11
+
12
+
13
+ def apply_delta(base_model_path, target_model_path, delta_path):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading delta")
19
+ delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20
+ delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21
+
22
+ print("Applying delta")
23
+ for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data += base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31
+ f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32
+ bparam = base.state_dict()[name]
33
+ param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34
+
35
+ print("Saving target model")
36
+ delta.save_pretrained(target_model_path)
37
+ delta_tokenizer.save_pretrained(target_model_path)
38
+
39
+
40
+ if __name__ == "__main__":
41
+ parser = argparse.ArgumentParser()
42
+ parser.add_argument("--base-model-path", type=str, required=True)
43
+ parser.add_argument("--target-model-path", type=str, required=True)
44
+ parser.add_argument("--delta-path", type=str, required=True)
45
+
46
+ args = parser.parse_args()
47
+
48
+ apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
llava/model/builder.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+ import shutil
19
+
20
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21
+ import torch
22
+ from llava.model import *
23
+ from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24
+
25
+
26
+ def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
27
+ kwargs = {"device_map": device_map}
28
+
29
+ if load_8bit:
30
+ kwargs['load_in_8bit'] = True
31
+ elif load_4bit:
32
+ kwargs['load_in_4bit'] = True
33
+ kwargs['quantization_config'] = BitsAndBytesConfig(
34
+ load_in_4bit=True,
35
+ bnb_4bit_compute_dtype=torch.float16,
36
+ bnb_4bit_use_double_quant=True,
37
+ bnb_4bit_quant_type='nf4'
38
+ )
39
+ else:
40
+ kwargs['torch_dtype'] = torch.float16
41
+
42
+ if 'llava' in model_name.lower():
43
+ # Load LLaVA model
44
+ if 'lora' in model_name.lower() and model_base is None:
45
+ warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
46
+ if 'lora' in model_name.lower() and model_base is not None:
47
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
48
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
49
+ print('Loading LLaVA from base model...')
50
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
51
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
52
+ if model.lm_head.weight.shape[0] != token_num:
53
+ model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
54
+ model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype))
55
+
56
+ print('Loading additional LLaVA weights...')
57
+ if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
58
+ non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
59
+ else:
60
+ # this is probably from HF Hub
61
+ from huggingface_hub import hf_hub_download
62
+ def load_from_hf(repo_id, filename, subfolder=None):
63
+ cache_file = hf_hub_download(
64
+ repo_id=repo_id,
65
+ filename=filename,
66
+ subfolder=subfolder)
67
+ return torch.load(cache_file, map_location='cpu')
68
+ non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
69
+ non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
70
+ if any(k.startswith('model.model.') for k in non_lora_trainables):
71
+ non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
72
+ model.load_state_dict(non_lora_trainables, strict=False)
73
+
74
+ from peft import PeftModel
75
+ print('Loading LoRA weights...')
76
+ model = PeftModel.from_pretrained(model, model_path)
77
+ print('Merging LoRA weights...')
78
+ model = model.merge_and_unload()
79
+ print('Model is loaded...')
80
+ elif model_base is not None:
81
+ # this may be mm projector only
82
+ print('Loading LLaVA from base model...')
83
+ if 'mpt' in model_name.lower():
84
+ if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
85
+ shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
86
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
87
+ cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
88
+ model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
89
+ else:
90
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
92
+ model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
93
+
94
+ mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
95
+ mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
96
+ model.load_state_dict(mm_projector_weights, strict=False)
97
+ else:
98
+ if 'mpt' in model_name.lower():
99
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
100
+ model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
101
+ else:
102
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
103
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
104
+ else:
105
+ # Load language model
106
+ if model_base is not None:
107
+ # PEFT model
108
+ from peft import PeftModel
109
+ tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
110
+ model = AutoModelForCausalLM.from_pretrained(model_base, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto")
111
+ print(f"Loading LoRA weights from {model_path}")
112
+ model = PeftModel.from_pretrained(model, model_path)
113
+ print(f"Merging weights")
114
+ model = model.merge_and_unload()
115
+ print('Convert to FP16...')
116
+ model.to(torch.float16)
117
+ else:
118
+ use_fast = False
119
+ if 'mpt' in model_name.lower():
120
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
121
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
122
+ else:
123
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
124
+ model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
125
+
126
+ image_processor = None
127
+
128
+ if 'llava' in model_name.lower():
129
+ mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
130
+ mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
131
+ if mm_use_im_patch_token:
132
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
133
+ if mm_use_im_start_end:
134
+ tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
135
+ model.resize_token_embeddings(len(tokenizer))
136
+
137
+ vision_tower = model.get_vision_tower()
138
+ if not vision_tower.is_loaded:
139
+ vision_tower.load_model()
140
+ vision_tower.to(device=device, dtype=torch.float16)
141
+ image_processor = vision_tower.image_processor
142
+
143
+ if hasattr(model.config, "max_sequence_length"):
144
+ context_len = model.config.max_sequence_length
145
+ else:
146
+ context_len = 2048
147
+
148
+ return tokenizer, model, image_processor, context_len
llava/model/consolidate.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ from llava.model import *
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def consolidate_ckpt(src_path, dst_path):
14
+ print("Loading model")
15
+ auto_upgrade(src_path)
16
+ src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+ src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18
+ src_model.save_pretrained(dst_path)
19
+ src_tokenizer.save_pretrained(dst_path)
20
+
21
+
22
+ if __name__ == "__main__":
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--src", type=str, required=True)
25
+ parser.add_argument("--dst", type=str, required=True)
26
+
27
+ args = parser.parse_args()
28
+
29
+ consolidate_ckpt(args.src, args.dst)
llava/model/llava_arch.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .multimodal_encoder.builder import build_vision_tower
22
+ from .multimodal_projector.builder import build_vision_projector
23
+
24
+ from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25
+
26
+
27
+ class LlavaMetaModel:
28
+
29
+ def __init__(self, config):
30
+ super(LlavaMetaModel, self).__init__(config)
31
+
32
+ if hasattr(config, "mm_vision_tower"):
33
+ self.vision_tower = build_vision_tower(config, delay_load=True)
34
+ self.mm_projector = build_vision_projector(config)
35
+
36
+ def get_vision_tower(self):
37
+ vision_tower = getattr(self, 'vision_tower', None)
38
+ if type(vision_tower) is list:
39
+ vision_tower = vision_tower[0]
40
+ return vision_tower
41
+
42
+ def initialize_vision_modules(self, model_args, fsdp=None):
43
+ vision_tower = model_args.vision_tower
44
+ mm_vision_select_layer = model_args.mm_vision_select_layer
45
+ mm_vision_select_feature = model_args.mm_vision_select_feature
46
+ pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
47
+
48
+ self.config.mm_vision_tower = vision_tower
49
+
50
+ if self.get_vision_tower() is None:
51
+ vision_tower = build_vision_tower(model_args)
52
+
53
+ if fsdp is not None and len(fsdp) > 0:
54
+ self.vision_tower = [vision_tower]
55
+ else:
56
+ self.vision_tower = vision_tower
57
+ else:
58
+ if fsdp is not None and len(fsdp) > 0:
59
+ vision_tower = self.vision_tower[0]
60
+ else:
61
+ vision_tower = self.vision_tower
62
+ vision_tower.load_model()
63
+
64
+ self.config.use_mm_proj = True
65
+ self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
66
+ self.config.mm_hidden_size = vision_tower.hidden_size
67
+ self.config.mm_vision_select_layer = mm_vision_select_layer
68
+ self.config.mm_vision_select_feature = mm_vision_select_feature
69
+
70
+ if getattr(self, 'mm_projector', None) is None:
71
+ self.mm_projector = build_vision_projector(self.config)
72
+
73
+ if pretrain_mm_mlp_adapter is not None:
74
+ mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
75
+ def get_w(weights, keyword):
76
+ return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
77
+
78
+ self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
79
+
80
+
81
+ class LlavaMetaForCausalLM(ABC):
82
+
83
+ @abstractmethod
84
+ def get_model(self):
85
+ pass
86
+
87
+ def get_vision_tower(self):
88
+ return self.get_model().get_vision_tower()
89
+
90
+ def encode_images(self, images):
91
+ image_features = self.get_model().get_vision_tower()(images)
92
+ image_features = self.get_model().mm_projector(image_features)
93
+ return image_features
94
+
95
+ def prepare_inputs_labels_for_multimodal(
96
+ self, input_ids, attention_mask, past_key_values, labels, images
97
+ ):
98
+ vision_tower = self.get_vision_tower()
99
+ if vision_tower is None or images is None or input_ids.shape[1] == 1:
100
+ if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
101
+ attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
102
+ return input_ids, attention_mask, past_key_values, None, labels
103
+
104
+ if type(images) is list or images.ndim == 5:
105
+ concat_images = torch.cat([image for image in images], dim=0)
106
+ image_features = self.encode_images(concat_images)
107
+ split_sizes = [image.shape[0] for image in images]
108
+ image_features = torch.split(image_features, split_sizes, dim=0)
109
+ image_features = [x.flatten(0, 1) for x in image_features]
110
+ else:
111
+ image_features = self.encode_images(images)
112
+
113
+ new_input_embeds = []
114
+ new_labels = [] if labels is not None else None
115
+ cur_image_idx = 0
116
+ for batch_idx, cur_input_ids in enumerate(input_ids):
117
+ if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
118
+ # multimodal LLM, but the current sample is not multimodal
119
+ # FIXME: this is a hacky fix, for deepspeed zero3 to work
120
+ half_len = cur_input_ids.shape[0] // 2
121
+ cur_image_features = image_features[cur_image_idx]
122
+ cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
123
+ cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
124
+ cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
125
+ new_input_embeds.append(cur_input_embeds)
126
+ if labels is not None:
127
+ new_labels.append(labels[batch_idx])
128
+ cur_image_idx += 1
129
+ continue
130
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
131
+ cur_new_input_embeds = []
132
+ if labels is not None:
133
+ cur_labels = labels[batch_idx]
134
+ cur_new_labels = []
135
+ assert cur_labels.shape == cur_input_ids.shape
136
+ while image_token_indices.numel() > 0:
137
+ cur_image_features = image_features[cur_image_idx]
138
+ image_token_start = image_token_indices[0]
139
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
140
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
141
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
142
+ cur_new_input_embeds.append(cur_image_features)
143
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
144
+ if labels is not None:
145
+ cur_new_labels.append(cur_labels[:image_token_start])
146
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
147
+ cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
148
+ cur_labels = cur_labels[image_token_start+2:]
149
+ else:
150
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
151
+ cur_new_input_embeds.append(cur_image_features)
152
+ if labels is not None:
153
+ cur_new_labels.append(cur_labels[:image_token_start])
154
+ cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
155
+ cur_labels = cur_labels[image_token_start+1:]
156
+ cur_image_idx += 1
157
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
158
+ cur_input_ids = cur_input_ids[image_token_start+2:]
159
+ else:
160
+ cur_input_ids = cur_input_ids[image_token_start+1:]
161
+ image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
162
+ if cur_input_ids.numel() > 0:
163
+ if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
164
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
165
+ else:
166
+ cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
167
+ if labels is not None:
168
+ cur_new_labels.append(cur_labels)
169
+ cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
170
+ cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
171
+ new_input_embeds.append(cur_new_input_embeds)
172
+ if labels is not None:
173
+ cur_new_labels = torch.cat(cur_new_labels, dim=0)
174
+ new_labels.append(cur_new_labels)
175
+
176
+ if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
177
+ max_len = max(x.shape[0] for x in new_input_embeds)
178
+
179
+ new_input_embeds_align = []
180
+ for cur_new_embed in new_input_embeds:
181
+ cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
182
+ new_input_embeds_align.append(cur_new_embed)
183
+ new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
184
+
185
+ if labels is not None:
186
+ new_labels_align = []
187
+ _new_labels = new_labels
188
+ for cur_new_label in new_labels:
189
+ cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
190
+ new_labels_align.append(cur_new_label)
191
+ new_labels = torch.stack(new_labels_align, dim=0)
192
+
193
+ if attention_mask is not None:
194
+ new_attention_mask = []
195
+ for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
196
+ new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
197
+ new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
198
+ cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
199
+ new_attention_mask.append(cur_new_attention_mask)
200
+ attention_mask = torch.stack(new_attention_mask, dim=0)
201
+ assert attention_mask.shape == new_labels.shape
202
+ else:
203
+ new_input_embeds = torch.stack(new_input_embeds, dim=0)
204
+ if labels is not None:
205
+ new_labels = torch.stack(new_labels, dim=0)
206
+
207
+ if attention_mask is not None:
208
+ new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
209
+ attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
210
+ assert attention_mask.shape == new_input_embeds.shape[:2]
211
+
212
+ return None, attention_mask, past_key_values, new_input_embeds, new_labels
213
+
214
+ def initialize_vision_tokenizer(self, model_args, tokenizer):
215
+ if model_args.mm_use_im_patch_token:
216
+ tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
217
+ self.resize_token_embeddings(len(tokenizer))
218
+
219
+ if model_args.mm_use_im_start_end:
220
+ num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
221
+ self.resize_token_embeddings(len(tokenizer))
222
+
223
+ if num_new_tokens > 0:
224
+ input_embeddings = self.get_input_embeddings().weight.data
225
+ output_embeddings = self.get_output_embeddings().weight.data
226
+
227
+ input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
228
+ dim=0, keepdim=True)
229
+ output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
230
+ dim=0, keepdim=True)
231
+
232
+ input_embeddings[-num_new_tokens:] = input_embeddings_avg
233
+ output_embeddings[-num_new_tokens:] = output_embeddings_avg
234
+
235
+ if model_args.tune_mm_mlp_adapter:
236
+ for p in self.get_input_embeddings().parameters():
237
+ p.requires_grad = True
238
+ for p in self.get_output_embeddings().parameters():
239
+ p.requires_grad = False
240
+
241
+ if model_args.pretrain_mm_mlp_adapter:
242
+ mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
243
+ embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
244
+ assert num_new_tokens == 2
245
+ if input_embeddings.shape == embed_tokens_weight.shape:
246
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
247
+ elif embed_tokens_weight.shape[0] == num_new_tokens:
248
+ input_embeddings[-num_new_tokens:] = embed_tokens_weight
249
+ else:
250
+ raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
251
+ elif model_args.mm_use_im_patch_token:
252
+ if model_args.tune_mm_mlp_adapter:
253
+ for p in self.get_input_embeddings().parameters():
254
+ p.requires_grad = False
255
+ for p in self.get_output_embeddings().parameters():
256
+ p.requires_grad = False
llava/model/make_delta.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4
+ """
5
+ import argparse
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from llava.model.utils import auto_upgrade
11
+
12
+
13
+ def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14
+ print("Loading base model")
15
+ base = AutoModelForCausalLM.from_pretrained(
16
+ base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17
+
18
+ print("Loading target model")
19
+ auto_upgrade(target_model_path)
20
+ target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21
+
22
+ print("Calculating delta")
23
+ for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24
+ if name not in base.state_dict():
25
+ assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26
+ continue
27
+ if param.data.shape == base.state_dict()[name].shape:
28
+ param.data -= base.state_dict()[name]
29
+ else:
30
+ assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31
+ bparam = base.state_dict()[name]
32
+ param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33
+
34
+ print("Saving delta")
35
+ if hub_repo_id:
36
+ kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37
+ else:
38
+ kwargs = {}
39
+ target.save_pretrained(delta_path, **kwargs)
40
+ target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41
+ target_tokenizer.save_pretrained(delta_path, **kwargs)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ parser = argparse.ArgumentParser()
46
+ parser.add_argument("--base-model-path", type=str, required=True)
47
+ parser.add_argument("--target-model-path", type=str, required=True)
48
+ parser.add_argument("--delta-path", type=str, required=True)
49
+ parser.add_argument("--hub-repo-id", type=str, default=None)
50
+ args = parser.parse_args()
51
+
52
+ make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
llava/model/utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoConfig
2
+
3
+
4
+ def auto_upgrade(config):
5
+ cfg = AutoConfig.from_pretrained(config)
6
+ if 'llava' in config and 'llava' not in cfg.model_type:
7
+ assert cfg.model_type == 'llama'
8
+ print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9
+ print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10
+ confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11
+ if confirm.lower() in ["y", "yes"]:
12
+ print("Upgrading checkpoint...")
13
+ assert len(cfg.architectures) == 1
14
+ setattr(cfg.__class__, "model_type", "llava")
15
+ cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16
+ cfg.save_pretrained(config)
17
+ print("Checkpoint upgraded.")
18
+ else:
19
+ print("Checkpoint upgrade aborted.")
20
+ exit(1)