toshi456 committed
Commit
7d0ed79
1 Parent(s): 744af63

Upload 14 files

llava/constants.py ADDED
@@ -0,0 +1,9 @@

CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
llava/conversation.py ADDED
@@ -0,0 +1,213 @@

import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    PLAIN = auto()
    TWO = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ": "
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v1 = Conversation(
    system="これは好奇心旺盛なユーザーと人工知能システムのチャットです。"
           "システムはユーザーの質問に親切、詳細、丁寧に答える。",
    roles=("ユーザー", "システム"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<EOD|LLM-jp>",  # if you use llm-jp: <EOD|LLM-jp>; gpt2 and gpt_neox: </s>
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

default_conversation = conv_llava_plain
conv_templates = {
    "v1": conv_vicuna_v1,
    "plain": conv_llava_plain,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
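
A minimal usage sketch of the templates above (not taken from the diff), assuming the `llava` package from this repository is on `PYTHONPATH`; the Japanese user message is only illustrative:

# Sketch: build a "v1" prompt for generation with conv_templates.
from llava.conversation import conv_templates

conv = conv_templates["v1"].copy()
conv.append_message(conv.roles[0], "<image>\nこの画像について説明してください。")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open for generation
prompt = conv.get_prompt()
print(prompt)  # system text, then "ユーザー: ... システム:" joined by sep / sep2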
llava/model/clip_encoder.py ADDED
@@ -0,0 +1,132 @@

from typing import Optional

import torch
import torch.nn as nn

from transformers import (
    CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig,
    SiglipVisionModel, SiglipImageProcessor, SiglipVisionConfig
)
from llava.s2wrapper import forward as multiscale_forward


class CLIPVisionTower(nn.Module):
    def __init__(
        self,
        vision_tower_name: str = "openai/clip-vit-large-patch14-336",
        mm_vision_select_layer: int = -2,  # v1.5 is -2
        mm_vision_select_feature: str = "patch",
        delay_load: bool = False,
        requires_grad: bool = False,
        scales: Optional[float] = None
    ):
        super().__init__()

        self.is_loaded = False
        self.requires_grad = requires_grad
        self.scales = scales

        self.vision_tower_name = vision_tower_name
        self.select_layer = mm_vision_select_layer
        self.select_feature = mm_vision_select_feature

        self.image_processor = None
        self.vision_tower = None

        if not delay_load:
            self.load_model()
        else:
            if "clip" in self.vision_tower_name:
                self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
            elif "siglip" in self.vision_tower_name:
                self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name)
            else:
                raise ValueError(f'Unsupported vision_tower_name: {self.vision_tower_name}')

    def load_model(self):
        if "clip" in self.vision_tower_name:
            self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
            self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        elif "siglip" in self.vision_tower_name:
            self.image_processor = SiglipImageProcessor.from_pretrained(self.vision_tower_name)
            self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name)
        else:
            raise ValueError(f'Unsupported vision_tower_name: {self.vision_tower_name}')
        self.vision_tower.requires_grad_(self.requires_grad)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            image_features = image_features[:, 1:]
        elif self.select_feature == 'cls_patch':
            image_features = image_features
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                if self.scales is None:
                    image_feature = self._forward_feature(image.unsqueeze(0))
                else:
                    image_feature = multiscale_forward(
                        self._forward_feature,
                        image.unsqueeze(0),
                        scales=self.scales,
                        num_prefix_token=0,
                        max_split_size=self.image_processor.size["height"]
                    )
                image_features.append(image_feature)
        else:
            if self.scales is None:
                image_features = self._forward_feature(images)
            else:
                image_features = multiscale_forward(
                    self._forward_feature,
                    images,
                    scales=self.scales,
                    num_prefix_token=0,
                    max_split_size=self.image_processor.size["height"]
                )

        return image_features

    def _forward_feature(self, inputs):
        return self.feature_select(self.vision_tower(inputs.to(device=self.device, dtype=self.dtype), output_hidden_states=True))

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        if self.scales is None:
            return self.config.hidden_size

        return self.config.hidden_size * len(self.scales)

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2
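
A short sketch of how the tower can be exercised on its own (not part of the commit); it assumes the repository is importable, downloads the CLIP weights on first run, and uses a hypothetical local file `sample.jpg`:

# Sketch: extract patch features for one image with CLIPVisionTower.
import torch
from PIL import Image
from llava.model.clip_encoder import CLIPVisionTower

tower = CLIPVisionTower("openai/clip-vit-large-patch14-336")   # loads model + image processor
image = Image.open("sample.jpg").convert("RGB")                # hypothetical input image
pixels = tower.image_processor(image, return_tensors="pt")["pixel_values"]
feats = tower(pixels.to(dtype=tower.dtype))
print(feats.shape)  # (1, 576, 1024) for 336px CLIP-L/14 with 'patch' feature selection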
llava/model/llava_arch.py ADDED
@@ -0,0 +1,250 @@

# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod

import torch

from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
from llava.model.clip_encoder import CLIPVisionTower
from llava.model.vision_projector import get_vision_projector


class LlavaMetaModel:

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.initialize_vision_modules(config)
        else:
            self.vision_tower = None
            self.mm_projector = None

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args):
        vision_tower = model_args.vision_tower if hasattr(model_args, "vision_tower") else model_args.mm_vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter if hasattr(model_args, "pretrain_mm_mlp_adapter") else None

        self.config.mm_vision_tower = vision_tower
        self.config.scales = model_args.scales if hasattr(model_args, 'scales') else None

        self.vision_tower = CLIPVisionTower(
            vision_tower,
            mm_vision_select_layer,
            mm_vision_select_feature,
            delay_load=True,
            scales=model_args.scales,
        )
        self.vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = self.vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        self.mm_projector = get_vision_projector(self.config)

        # In case it is frozen by LoRA
        for p in self.mm_projector.parameters():
            p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))


class LlavaMetaForCausalLM(ABC):
    base_model = ""  # gpt2 or llama or gptneox

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def embed(self, input_ids):
        if self.base_model == "gpt2":
            return self.transformer.wte(input_ids)
        elif self.base_model == "gpt_neox":
            return self.embed_in(input_ids)  # NeoX
        elif self.base_model == "llama":
            return self.get_model().embed_tokens(input_ids)  # Llama

    def prepare_inputs_labels_for_multimodal(
        self, input_ids, position_ids, attention_mask, past_key_values, labels, images
    ):
        vision_tower = self.get_vision_tower()
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                target_shape = past_key_values[-1][-1].shape[-2] + 1
                attention_mask = torch.cat((attention_mask, torch.ones(
                    (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device
                )), dim=1)
                position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
            return input_ids, position_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            # Handle the video case (a list / 5-D batch of frame stacks)
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
        else:
            image_features = self.encode_images(images).to(self.device)

        # Let's just add dummy tensors if they do not exist,
        # it is a headache to deal with None all the time.
        # But it is not ideal, and if you have a better idea,
        # please open an issue / submit a PR, thanks.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # remove the padding using attention_mask -- TODO: double check
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
        labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.embed(cur_input_ids)
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []

            # Split the tokens before and after each IMAGE_TOKEN_INDEX,
            # e.g. input_ids -> cur_input_ids_noim
            #      [1 2 3 -200 4 5 6] -> [1 2 3], [4 5 6]
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]

            # cur_input_embeds_no_im[0].size() (27, 768)
            # cur_input_embeds_no_im[1].size() (xxx, 768)
            cur_input_embeds = self.embed(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # Replace each IMAGE_TOKEN_INDEX position with the image features,
            # e.g. cur_image_features.size() (576, 768)
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Combine them
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
                new_input_embeds_padded.append(torch.cat((
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                    cur_new_embed
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((
                    cur_new_embed,
                    torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
                ), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded

        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
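
A standalone illustration of the splitting step that `prepare_inputs_labels_for_multimodal` performs (torch only, not from the commit): each IMAGE_TOKEN_INDEX placeholder is cut out of the token sequence and the projected image features are later spliced into the gap.

import torch

IMAGE_TOKEN_INDEX = -200
cur_input_ids = torch.tensor([1, 2, 3, IMAGE_TOKEN_INDEX, 4, 5, 6])

# Same index arithmetic as in the method above.
image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
chunks = [cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]]
          for i in range(len(image_token_indices) - 1)]
print(chunks)  # [tensor([1, 2, 3]), tensor([4, 5, 6])] -- image features go in between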
llava/model/llava_gpt2.py ADDED
@@ -0,0 +1,111 @@

# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
    GPT2LMHeadModel, GPT2Config, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(GPT2Config):
    model_type = "llava-jp"


class LlavaGpt2Model(LlavaMetaModel, PreTrainedModel):
    config_class = LlavaConfig

    def __init__(self, config: GPT2Config):
        super(LlavaGpt2Model, self).__init__(config)


class LlavaGpt2ForCausalLM(GPT2LMHeadModel, LlavaMetaForCausalLM):
    config_class = LlavaConfig
    base_model = "gpt2"

    def __init__(self, config):
        super(LlavaGpt2ForCausalLM, self).__init__(config)
        self.model = LlavaGpt2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs


AutoConfig.register("llava-jp", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaGpt2ForCausalLM)
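
Because importing this module registers the "llava-jp" config with `AutoConfig` / `AutoModelForCausalLM`, a checkpoint saved with this class can be loaded generically. A hedged sketch (the checkpoint path below is only the default `output_dir` from the training arguments, used as a placeholder):

from transformers import AutoModelForCausalLM
import llava.model.llava_gpt2  # noqa: F401  (side effect: registers LlavaConfig / LlavaGpt2ForCausalLM)

model = AutoModelForCausalLM.from_pretrained("./output_llava/checkpoints/llava-v1.5-japanese-gpt2-xsmall")
print(type(model).__name__)  # LlavaGpt2ForCausalLM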
llava/model/llava_gpt_neox.py ADDED
@@ -0,0 +1,112 @@

# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, \
    GPTNeoXModel, GPTNeoXForCausalLM, GPTNeoXConfig, PreTrainedModel

from transformers.modeling_outputs import CausalLMOutputWithPast

from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(GPTNeoXConfig):
    model_type = "llava-jp"


class LlavaGptNeoxModel(LlavaMetaModel, GPTNeoXModel):
    config_class = LlavaConfig

    def __init__(self, config: GPTNeoXConfig):
        super(LlavaGptNeoxModel, self).__init__(config)


class LlavaGptNeoxForCausalLM(PreTrainedModel, LlavaMetaForCausalLM):
    config_class = LlavaConfig
    base_model = "gpt_neox"

    def __init__(self, config):
        super(LlavaGptNeoxForCausalLM, self).__init__(config)
        self.model = LlavaGptNeoxModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs


AutoConfig.register("llava-jp", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaGptNeoxForCausalLM)
llava/model/llava_llama.py ADDED
@@ -0,0 +1,110 @@

# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import AutoConfig, AutoModelForCausalLM, LlamaForCausalLM, \
    LlamaModel, LlamaConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaConfig(LlamaConfig):
    model_type = "llava-jp"


class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        super(LlavaLlamaModel, self).__init__(config)


class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig
    base_model = "llama"

    def __init__(self, config):
        super(LlavaLlamaForCausalLM, self).__init__(config)
        self.model = LlavaLlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images
            )

        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        images = kwargs.pop("images", None)
        _inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is not None:
            _inputs['images'] = images
        return _inputs


AutoConfig.register("llava-jp", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
llava/model/vision_projector.py ADDED
@@ -0,0 +1,106 @@

import math
import re

import torch
import torch.nn as nn


class IdentityMap(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class FeatureIRLayer(nn.Module):
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mlp(x)


class TokenDownLayer(nn.Module):
    def __init__(self, shape) -> None:
        super().__init__()
        self.dwn = nn.Sequential(
            nn.AdaptiveAvgPool2d(shape)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        if h * h == num_tokens:
            x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        else:
            # FIXME: fails for some token counts (non-rectangular grids)
            w = int(num_tokens / h)
            assert w * h == num_tokens
            x = x.permute(0, 2, 1).reshape(b, -1, w, h)

        x = self.dwn(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class PosInjectLayer(nn.Module):
    # https://github.com/Meituan-AutoML/Twins/blob/main/gvt.py
    def __init__(self, in_dim: int, out_dim: int, stride: int = 1) -> None:
        super().__init__()
        self.peg = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 3, stride, 1, bias=True, groups=out_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, num_tokens, c = x.shape
        h = int(math.sqrt(num_tokens))
        assert h * h == num_tokens
        cnn_feat = x.transpose(1, 2).view(b, c, h, h)
        x = self.peg(cnn_feat) + cnn_feat
        x = x.flatten(2).transpose(1, 2)
        return x


class LDPNetV2Projector(nn.Module):
    # https://github.com/Meituan-AutoML/MobileVLM/blob/main/mobilevlm/model/vision_projector.py
    def __init__(self, config=None):
        super().__init__()
        inc, ouc = config.mm_hidden_size, config.hidden_size
        self.mlp = FeatureIRLayer(inc, ouc)
        self.dwn = TokenDownLayer((12, 12))
        self.peg = PosInjectLayer(ouc, ouc, stride=1)

    def forward(self, x):
        x = self.mlp(x)
        x = self.dwn(x)
        x = self.peg(x)
        return x


def get_vision_projector(config, delay_load=False, **kwargs):
    projector_type = getattr(config, 'mm_projector_type', 'linear')

    if projector_type == 'linear':
        return nn.Linear(config.mm_hidden_size, config.hidden_size)
    elif projector_type == 'identity':
        return IdentityMap()
    elif projector_type == 'ldpnetv2':
        return LDPNetV2Projector(config)

    mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
    if mlp_gelu_match:
        mlp_depth = int(mlp_gelu_match.group(1))
        modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
        for _ in range(1, mlp_depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)

    raise ValueError(f'Unknown projector type: {projector_type}')
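
A minimal sketch (not from the commit) of the default 'mlp2x_gelu' projector, built from a dummy config; the hidden sizes are illustrative CLIP-L / GPT-2-like values:

from types import SimpleNamespace
import torch
from llava.model.vision_projector import get_vision_projector

config = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=768)
projector = get_vision_projector(config)   # Linear(1024->768) + GELU + Linear(768->768)
features = torch.randn(1, 576, 1024)       # a batch of 576 patch tokens
print(projector(features).shape)           # torch.Size([1, 576, 768])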
llava/s2wrapper/__init__.py ADDED
@@ -0,0 +1,2 @@

from .core import *
from .utils import *
llava/s2wrapper/core.py ADDED
@@ -0,0 +1,74 @@

# ------------------------------------------------------------------------------------------
# Copyright (c) 2024 Baifeng Shi.
# All rights reserved.
#
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import math
import torch
import torch.nn.functional as F
from einops import rearrange
from .utils import split_chessboard, merge_chessboard


def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0,
            output_shape='bnc'):

    assert input.dim() == 4, "Input image must be in the shape of BxCxHxW."
    assert input.shape[2] == input.shape[3], "Currently only square images are supported."
    assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)."
    assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token."

    b, c, input_size, _ = input.shape

    # image size for each scale
    assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes."
    img_sizes = img_sizes or [int(input_size * scale) for scale in scales]

    # prepare multiscale inputs
    max_split_size = max_split_size or input_size  # The maximum size of each split of image. Set as the input size by default
    num_splits = [math.ceil(size / max_split_size) for size in img_sizes]  # number of splits for each scale
    input_multiscale = []
    for size, num_split in zip(img_sizes, num_splits):
        x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype)
        x = split_chessboard(x, num_split=num_split)
        input_multiscale.append(x)

    # run feedforward on each scale
    outs_multiscale = [model(x) for x in input_multiscale]
    if num_prefix_token > 0:
        outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale]
        outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale]
    if output_shape == 'bnc':
        height = int(outs_multiscale[0].shape[1] ** 0.5)
        if height ** 2 == outs_multiscale[0].shape[1]:
            width = height
        else:
            width = int(outs_multiscale[0].shape[1] / height)
            assert width * height == outs_multiscale[0].shape[1]

        outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=height, w=width)
                           for out in outs_multiscale]

    # merge outputs of different splits for each scale separately
    outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)]

    # interpolate outputs from different scales and concat together
    output_size = [height, width]
    out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size,
                                   mode='area').to(outs_multiscale[i].dtype)
                     for i in range(len(outs_multiscale))], dim=1)
    if output_shape == 'bnc':
        out = rearrange(out, 'b c h w -> b (h w) c')
    if num_prefix_token > 0:
        # take the mean of prefix tokens from different splits for each scale
        outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale]
        out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1)
        out = torch.cat([out_prefix_multiscale, out], dim=1)

    return out
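
A hedged sketch of the multiscale wrapper with a toy "ViT" (a bare 16x16 patch embedding that returns BxNxC tokens); with scales=[1, 2] the channel dimension doubles while the token count stays that of the base scale:

import torch
import torch.nn as nn
from llava.s2wrapper import forward as multiscale_forward

class ToyViT(nn.Module):
    def __init__(self, dim=64):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=16, stride=16)

    def forward(self, x):  # BxCxHxW -> BxNxC
        return self.patch_embed(x).flatten(2).transpose(1, 2)

model = ToyViT()
images = torch.randn(2, 3, 224, 224)
out = multiscale_forward(model, images, scales=[1, 2], max_split_size=224)
print(out.shape)  # torch.Size([2, 196, 128]): 14x14 tokens, 64 channels per scale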
llava/s2wrapper/utils.py ADDED
@@ -0,0 +1,32 @@

# ------------------------------------------------------------------------------------------
# Copyright (c) 2024 Baifeng Shi.
# All rights reserved.
#
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

import torch


def split_chessboard(x, num_split):
    """
    x: b * c * h * w
    Divide x into num_split**2 sub-squares and concatenate all the sub-squares on the batch dimension.
    """
    B, C, H, W = x.shape
    assert H % num_split == 0 and W % num_split == 0
    h, w = H // num_split, W // num_split
    x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w] for i in range(num_split) for j in range(num_split)], dim=0)
    return x_split


def merge_chessboard(x, num_split):
    """
    x: b * c * h * w
    Assuming x contains num_split**2 sub-squares concatenated along the batch dimension,
    merge the sub-squares back into the original whole square (inverse of split_chessboard).
    """
    B, C, H, W = x.shape
    assert B % (num_split**2) == 0
    b = B // (num_split**2)
    x_merge = torch.cat([torch.cat([x[(i*num_split + j)*b:(i*num_split + j + 1)*b] for j in range(num_split)], dim=-1)
                         for i in range(num_split)], dim=-2)
    return x_merge
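
A quick round-trip check (not from the commit) showing that the two helpers are exact inverses, which is what lets the wrapper tile large scales and stitch the features back together:

import torch
from llava.s2wrapper.utils import split_chessboard, merge_chessboard

x = torch.randn(2, 3, 8, 8)
tiles = split_chessboard(x, num_split=2)        # (8, 3, 4, 4): 4 tiles stacked on the batch dim
restored = merge_chessboard(tiles, num_split=2)
assert torch.equal(restored, x)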
llava/train/arguments_dataclass.py ADDED
@@ -0,0 +1,88 @@

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Optional

import transformers


@dataclass
class ModelArguments:
    base_model: Optional[str] = field(default="gpt2",
                                      metadata={"help": "gpt2 or gpt_neox or llama"})
    model_name_or_path: Optional[str] = field(default="rinna/japanese-gpt2-xsmall")
    version: Optional[str] = field(default="plain")
    freeze_backbone: bool = field(default=False)  # whether to freeze the LLM
    tune_mm_mlp_adapter: bool = field(default=False)  # during pretraining, only the mm_mlp_adapter is saved
    vision_tower: Optional[str] = field(default="openai/clip-vit-large-patch14-336")
    mm_vision_select_layer: Optional[int] = field(default=-2)  # default to the second-to-last layer
    pretrain_mm_mlp_adapter: Optional[str] = field(default=None)  # set when fine-tuning
    mm_projector_type: Optional[str] = field(default='mlp2x_gelu')  # two-layer MLP
    mm_vision_select_feature: Optional[str] = field(default="patch")
    scales: Optional[list[float]] = field(default=None)


@dataclass
class DataArguments:
    data_path: str = field(default="",
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default="/home/toshi/work/llava_jp/input/LLaVA-CC3M-Pretrain-595K/images",
                                        metadata={"help": "Path to image data."})
    image_aspect_ratio: str = 'square'
    image_size: Optional[int] = None


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=1024,
        metadata={
            "help":
                "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    mm_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)  # dataset sampler option

    fp16: bool = field(default=False)
    bf16: bool = field(default=False)
    output_dir: str = field(default="./output_llava/checkpoints/llava-v1.5-japanese-gpt2-xsmall")
    num_train_epochs: int = field(default=1)
    per_device_train_batch_size: int = field(default=32)
    per_device_eval_batch_size: int = field(default=4)
    gradient_accumulation_steps: int = field(default=1)
    evaluation_strategy: str = field(default="no")
    save_strategy: str = field(default="steps")
    save_steps: int = field(default=24000)
    save_total_limit: int = field(default=1)
    learning_rate: float = field(default=1e-3)
    weight_decay: float = field(default=0.)
    warmup_ratio: float = field(default=0.03)
    logging_steps: int = field(default=1)
    gradient_checkpointing: bool = field(default=True)
    dataloader_num_workers: int = field(default=16)
    lr_scheduler_type: str = field(default="cosine")
    seed: int = field(default=42)
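
These dataclasses are typically consumed through `transformers.HfArgumentParser`; a minimal sketch (the argument values are illustrative, not taken from the commit):

from transformers import HfArgumentParser
from llava.train.arguments_dataclass import ModelArguments, DataArguments, TrainingArguments

parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(
    args=["--output_dir", "./output_llava/checkpoints/demo", "--bf16", "True"]
)
print(model_args.vision_tower, training_args.learning_rate)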
llava/train/dataset.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Dict
7
+
8
+ from typing import Sequence
9
+
10
+ import torch
11
+ import transformers
12
+
13
+ from PIL import Image
14
+ from torch.utils.data import Dataset
15
+
16
+ from llava import conversation as conversation_lib
17
+ from llava.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
18
+ from llava.train.arguments_dataclass import DataArguments
19
+
20
+
21
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
22
+ prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
23
+
24
+ def insert_separator(X, sep):
25
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
26
+
27
+ input_ids = []
28
+ offset = 0
29
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
30
+ offset = 1
31
+ input_ids.append(prompt_chunks[0][0])
32
+
33
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
34
+ input_ids.extend(x[offset:])
35
+
36
+ if return_tensors is not None:
37
+ if return_tensors == 'pt':
38
+ return torch.tensor(input_ids, dtype=torch.long)
39
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
40
+ return input_ids
41
+
42
+
43
+ def preprocess_multimodal(
44
+ sources: Sequence[str],
45
+ data_args: DataArguments
46
+ ) -> Dict:
47
+ is_multimodal = data_args.is_multimodal
48
+ if not is_multimodal:
49
+ return sources
50
+
51
+ for source in sources:
52
+ for sentence in source:
53
+ if DEFAULT_IMAGE_TOKEN in sentence['value']:
54
+ sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
55
+ sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
56
+ sentence['value'] = sentence['value'].strip()
57
+ replace_token = DEFAULT_IMAGE_TOKEN
58
+ sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
59
+
60
+ return sources
61
+
62
+
63
+ def preprocess_plain(
64
+ sources: Sequence[str],
65
+ tokenizer: transformers.PreTrainedTokenizer,
66
+ ) -> Dict:
67
+ # add end signal and concatenate together
68
+ conversations = []
69
+ for source in sources:
70
+ assert len(source) == 2
71
+ assert DEFAULT_IMAGE_TOKEN in source[0]['value']
72
+ source[0]['value'] = DEFAULT_IMAGE_TOKEN
73
+ conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
74
+ conversations.append(conversation)
75
+ # tokenize conversations
76
+ input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
77
+ targets = copy.deepcopy(input_ids)
78
+ for target, source in zip(targets, sources):
79
+ tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
80
+ target[:tokenized_len] = IGNORE_INDEX
81
+
82
+ return dict(input_ids=input_ids, labels=targets)
83
+
84
+
85
+ def preprocess_v1(
86
+ sources,
87
+ tokenizer: transformers.PreTrainedTokenizer,
88
+ has_image: bool = False
89
+ ) -> Dict:
90
+ conv = conversation_lib.default_conversation.copy()
91
+ roles = {"ユーザー": conv.roles[0], "システム": conv.roles[1]}
92
+
93
+ # Apply prompt templates
94
+ conversations = []
95
+ for i, source in enumerate(sources):
96
+ if roles[source[0]["from"]] != conv.roles[0]:
97
+ # Skip the first one if it is not from human
98
+ source = source[1:]
99
+
100
+ conv.messages = []
101
+ for j, sentence in enumerate(source):
102
+ role = roles[sentence["from"]]
103
+ assert role == conv.roles[j % 2], f"{i}"
104
+ conv.append_message(role, sentence["value"])
105
+ conversations.append(conv.get_prompt())
106
+ # Tokenize conversations
107
+
108
+ if has_image:
109
+ input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
110
+ else:
111
+ input_ids = tokenizer(
112
+ conversations,
113
+ return_tensors="pt",
114
+ padding="longest",
115
+ max_length=tokenizer.model_max_length,
116
+ truncation=True,
117
+ ).input_ids
118
+
119
+ targets = input_ids.clone()
120
+
121
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
122
+
123
+ # Mask targets
124
+ sep = conv.sep + conv.roles[1] + ": "
125
+ for conversation, target in zip(conversations, targets):
126
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
127
+
128
+ rounds = conversation.split(conv.sep2)
129
+ cur_len = 0 #1
130
+ target[:cur_len] = IGNORE_INDEX
131
+ for i, rou in enumerate(rounds):
132
+ if rou == "":
133
+ break
134
+
135
+ parts = rou.split(sep)
136
+ if len(parts) != 2:
137
+ break
138
+ parts[0] += sep
139
+
140
+ if has_image:
141
+ round_len = len(tokenizer_image_token(rou, tokenizer))
142
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
143
+ else:
144
+ round_len = len(tokenizer(rou).input_ids)
145
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
146
+
147
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
148
+ cur_len += round_len
149
+ target[cur_len:] = IGNORE_INDEX
150
+
151
+ return dict(
152
+ input_ids=input_ids,
153
+ labels=targets,
154
+ )
155
+
156
+
157
+ def preprocess(
158
+ sources: Sequence[str],
159
+ tokenizer: transformers.PreTrainedTokenizer,
160
+ has_image: bool = False
161
+ ) -> Dict:
162
+ """
163
+ Given a list of sources, each is a conversation list. This transform:
164
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
165
+ 2. Concatenate conversations together;
166
+ 3. Tokenize the concatenated conversation;
167
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
168
+ """
169
+ if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN:
170
+ return preprocess_plain(sources, tokenizer)
171
+ elif conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.TWO:
172
+ return preprocess_v1(sources, tokenizer, has_image)
173
+ else:
174
+ raise ValueError(f"Invalid style: {conversation_lib.default_conversation.sep_style}")
175
+
176
+
177
+ class LazySupervisedDataset(Dataset):
178
+ """Dataset for supervised fine-tuning."""
179
+
180
+ def __init__(
181
+ self, data_path: str,
182
+ tokenizer: transformers.PreTrainedTokenizer,
183
+ data_args: DataArguments,
184
+ ):
185
+ super(LazySupervisedDataset, self).__init__()
186
+
187
+ list_data_dict = json.load(open(data_path, "r"))
188
+
189
+ from pathlib import Path
190
+
191
+ print("Formatting inputs...Skip in lazy mode")
192
+ self.tokenizer = tokenizer
193
+ self.list_data_dict = [i for i in list_data_dict if Path(data_args.image_folder, i['image']).is_file()]
194
+ self.data_args = data_args
195
+
196
+ def __len__(self):
197
+ return len(self.list_data_dict)
198
+
199
+ @property
200
+ def lengths(self):
201
+ length_list = []
202
+ for sample in self.list_data_dict:
203
+ img_tokens = 128 if 'image' in sample else 0
204
+ length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
205
+ return length_list
206
+
207
+ @property
208
+ def modality_lengths(self):
209
+ length_list = []
210
+ for sample in self.list_data_dict:
211
+ cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
212
+ cur_len = cur_len if 'images' in sample else -cur_len
213
+ length_list.append(cur_len)
214
+ return length_list
215
+
216
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
217
+ sources = self.list_data_dict[i]
218
+ if isinstance(i, int):
219
+ sources = [sources]
220
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
221
+ if 'image' in sources[0]:
222
+ image_file = self.list_data_dict[i]['image']
223
+ image_folder = self.data_args.image_folder
224
+ processor = self.data_args.image_processor
225
+ image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
226
+ if self.data_args.image_aspect_ratio == 'pad':
227
+ def expand2square(pil_img, background_color):
228
+ width, height = pil_img.size
229
+ if width == height:
230
+ return pil_img
231
+ elif width > height:
232
+ result = Image.new(pil_img.mode, (width, width), background_color)
233
+ result.paste(pil_img, (0, (width - height) // 2))
234
+ return result
235
+ else:
236
+ result = Image.new(pil_img.mode, (height, height), background_color)
237
+ result.paste(pil_img, ((height - width) // 2, 0))
238
+ return result
239
+ image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
240
+ image = processor.preprocess(
241
+ image,
242
+ return_tensors='pt',
243
+ size={"height": self.data_args.image_size, "width": self.data_args.image_size}
244
+ )['pixel_values'][0]
245
+ else:
246
+ image = processor.preprocess(
247
+ image,
248
+ return_tensors='pt',
249
+ size={"height": self.data_args.image_size, "width": self.data_args.image_size}
250
+ )['pixel_values'][0]
251
+ sources = preprocess_multimodal(
252
+ copy.deepcopy([e["conversations"] for e in sources]),
253
+ self.data_args
254
+ )
255
+ else:
256
+ sources = copy.deepcopy([e["conversations"] for e in sources])
257
+ data_dict = preprocess(
258
+ sources,
259
+ self.tokenizer,
260
+ has_image=('image' in self.list_data_dict[i]))
261
+ if isinstance(i, int):
262
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
263
+ labels=data_dict["labels"][0])
264
+
265
+ # image exists in the data
266
+ if 'image' in self.list_data_dict[i]:
267
+ data_dict['images'] = image
268
+ elif self.data_args.is_multimodal:
269
+ # image does not exist in the data, but the model is multimodal
270
+ crop_size = self.data_args.image_processor.crop_size
271
+ data_dict['images'] = torch.zeros(3, crop_size['height'], crop_size['width'])
272
+ return data_dict
273
+
274
+
275
+ @dataclass
276
+ class DataCollatorForSupervisedDataset(object):
277
+ """Collate examples for supervised fine-tuning."""
278
+
279
+ tokenizer: transformers.PreTrainedTokenizer
280
+
281
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
282
+ input_ids, labels = tuple([instance[key] for instance in instances]
283
+ for key in ("input_ids", "labels"))
284
+ input_ids = torch.nn.utils.rnn.pad_sequence(
285
+ input_ids,
286
+ batch_first=True,
287
+ padding_value=self.tokenizer.pad_token_id)
288
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
289
+ batch_first=True,
290
+ padding_value=IGNORE_INDEX)
291
+ input_ids = input_ids[:, :self.tokenizer.model_max_length]
292
+ labels = labels[:, :self.tokenizer.model_max_length]
293
+ batch = dict(
294
+ input_ids=input_ids,
295
+ labels=labels,
296
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
297
+ )
298
+
299
+ if 'images' in instances[0]:
300
+ images = [instance['images'] for instance in instances]
301
+ if all(x is not None and x.shape == images[0].shape for x in images):
302
+ batch['images'] = torch.stack(images)
303
+ else:
304
+ batch['images'] = images
305
+
306
+ return batch
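For orientation (not part of the uploaded file): a minimal sketch of how the LazySupervisedDataset and DataCollatorForSupervisedDataset above might be wired into a plain PyTorch DataLoader. It assumes the DataArguments dataclass defined earlier in the training script, a configured tokenizer, and that data_args.image_processor / image_folder are already populated; the helper name build_supervised_dataloader is illustrative only.

# Minimal wiring sketch -- illustrative, not part of the diff.
# Assumes this runs alongside the training script above, so LazySupervisedDataset and
# DataCollatorForSupervisedDataset are in scope, and DataArguments provides data_path.
from torch.utils.data import DataLoader

def build_supervised_dataloader(data_args, tokenizer, batch_size=4):
    dataset = LazySupervisedDataset(
        data_path=data_args.data_path,   # path to the JSON conversation file (assumed field)
        tokenizer=tokenizer,
        data_args=data_args,
    )
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    # Each batch holds input_ids, labels, attention_mask and, for multimodal
    # samples, a stacked "images" tensor.
    return DataLoader(dataset, batch_size=batch_size, collate_fn=collator)

In practice the HuggingFace Trainer builds this loader itself; the sketch only shows which pieces the dataset and collator expect.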
llava/train/llava_trainer.py ADDED
@@ -0,0 +1,237 @@
1
+ import os
2
+ from typing import List, Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from torch.utils.data import Sampler
8
+ from transformers import Trainer
9
+ from transformers.trainer import (
10
+ get_parameter_names,
11
+ has_length,
12
+ ALL_LAYERNORM_LAYERS,
13
+ logger,
14
+ )
15
+
16
+
17
+ def split_to_even_chunks(indices, lengths, num_chunks):
18
+ """
19
+ Split a list of indices into `num_chunks` chunks of roughly equal total length.
20
+ """
21
+
22
+ if len(indices) % num_chunks != 0:
23
+ return [indices[i::num_chunks] for i in range(num_chunks)]
24
+
25
+ num_indices_per_chunk = len(indices) // num_chunks
26
+
27
+ chunks = [[] for _ in range(num_chunks)]
28
+ chunks_lengths = [0 for _ in range(num_chunks)]
29
+ for index in indices:
30
+ shortest_chunk = chunks_lengths.index(min(chunks_lengths))
31
+ chunks[shortest_chunk].append(index)
32
+ chunks_lengths[shortest_chunk] += lengths[index]
33
+ if len(chunks[shortest_chunk]) == num_indices_per_chunk:
34
+ chunks_lengths[shortest_chunk] = float("inf")
35
+
36
+ return chunks
37
+
38
+
39
+ def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
40
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
41
+ assert all(l != 0 for l in lengths), "Should not have zero length."
42
+ if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
43
+ # all samples are in the same modality
44
+ return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
45
+ mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
46
+ lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
47
+
48
+ mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
49
+ lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
50
+ megabatch_size = world_size * batch_size
51
+ mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
52
+ lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
53
+
54
+ last_mm = mm_megabatches[-1]
55
+ last_lang = lang_megabatches[-1]
56
+ additional_batch = last_mm + last_lang
57
+ megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
58
+ megabatch_indices = torch.randperm(len(megabatches), generator=generator)
59
+ megabatches = [megabatches[i] for i in megabatch_indices]
60
+
61
+ if len(additional_batch) > 0:
62
+ megabatches.append(sorted(additional_batch))
63
+
64
+ return [i for megabatch in megabatches for i in megabatch]
65
+
66
+
67
+ def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
68
+ # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
69
+ indices = torch.randperm(len(lengths), generator=generator)
70
+ megabatch_size = world_size * batch_size
71
+ megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
72
+ megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
73
+ megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
74
+
75
+ return [i for megabatch in megabatches for batch in megabatch for i in batch]
76
+
77
+
78
+ class LengthGroupedSampler(Sampler):
79
+ # NOTE: this appears to be used only during fine-tuning
80
+ r"""
81
+ Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
82
+ keeping a bit of randomness.
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ batch_size: int,
88
+ world_size: int,
89
+ lengths: Optional[List[int]] = None,
90
+ generator=None,
91
+ group_by_modality: bool = False,
92
+ ):
93
+ if lengths is None:
94
+ raise ValueError("Lengths must be provided.")
95
+
96
+ self.batch_size = batch_size
97
+ self.world_size = world_size
98
+ self.lengths = lengths
99
+ self.generator = generator
100
+ self.group_by_modality = group_by_modality
101
+
102
+ def __len__(self):
103
+ return len(self.lengths)
104
+
105
+ def __iter__(self):
106
+ if self.group_by_modality:
107
+ indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
108
+ else:
109
+ indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
110
+ return iter(indices)
111
+
112
+
113
+ def get_mm_adapter_state(named_params, keys_to_match):
114
+ to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
115
+ to_return = {k: v.detach().cpu().clone() for k, v in to_return.items()}
116
+ return to_return
117
+
118
+
119
+ class LLaVATrainer(Trainer):
120
+
121
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
122
+ if self.train_dataset is None or not has_length(self.train_dataset):
123
+ return None
124
+
125
+ if self.args.group_by_modality_length:
126
+ lengths = self.train_dataset.modality_lengths
127
+ return LengthGroupedSampler(
128
+ self.args.train_batch_size,
129
+ world_size=self.args.world_size * self.args.gradient_accumulation_steps,
130
+ lengths=lengths,
131
+ group_by_modality=True,
132
+ )
133
+ else:
134
+ return super()._get_train_sampler()
135
+
136
+ def create_optimizer(self):
137
+ """
138
+ Setup the optimizer.
139
+
140
+ We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
141
+ Trainer's init through `optimizers`, or subclass and override this method in a subclass.
142
+ """
143
+ opt_model = self.model
144
+
145
+ if self.optimizer is None:
146
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
147
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
148
+ if self.args.mm_projector_lr is not None:
149
+ projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
150
+ optimizer_grouped_parameters = [
151
+ {
152
+ "params": [
153
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
154
+ ],
155
+ "weight_decay": self.args.weight_decay,
156
+ },
157
+ {
158
+ "params": [
159
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
160
+ ],
161
+ "weight_decay": 0.0,
162
+ },
163
+ {
164
+ "params": [
165
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
166
+ ],
167
+ "weight_decay": self.args.weight_decay,
168
+ "lr": self.args.mm_projector_lr,
169
+ },
170
+ {
171
+ "params": [
172
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
173
+ ],
174
+ "weight_decay": 0.0,
175
+ "lr": self.args.mm_projector_lr,
176
+ },
177
+ ]
178
+ else:
179
+ optimizer_grouped_parameters = [
180
+ {
181
+ "params": [
182
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
183
+ ],
184
+ "weight_decay": self.args.weight_decay,
185
+ },
186
+ {
187
+ "params": [
188
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
189
+ ],
190
+ "weight_decay": 0.0,
191
+ },
192
+ ]
193
+
194
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
195
+
196
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
197
+ if optimizer_cls.__name__ == "Adam8bit":
198
+ import bitsandbytes
199
+
200
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
201
+
202
+ skipped = 0
203
+ for module in opt_model.modules():
204
+ if isinstance(module, nn.Embedding):
205
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
206
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
207
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
208
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
209
+ logger.info(f"skipped: {skipped/2**20}M params")
210
+
211
+ return self.optimizer
212
+
213
+ def _save_checkpoint(self, model, trial, metrics=None):
214
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
215
+ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
216
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
217
+
218
+ run_dir = self._get_output_dir(trial=trial)
219
+ output_dir = os.path.join(run_dir, checkpoint_folder)
220
+
221
+ # Only save Adapter
222
+ #keys_to_match = ['mm_projector', 'vision_resampler']
223
+ keys_to_match = ['mm_projector']
224
+ weight_to_save = get_mm_adapter_state(self.model.named_parameters(), keys_to_match)
225
+ #weight_to_save = self.model.named_parameters().detach().cpu().clone()
226
+
227
+ if self.args.local_rank == 0 or self.args.local_rank == -1:
228
+ self.model.config.save_pretrained(output_dir)
229
+ torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
230
+ else:
231
+ super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
232
+
233
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
234
+ if getattr(self.args, 'tune_mm_mlp_adapter', False):
235
+ pass
236
+ else:
237
+ super(LLaVATrainer, self)._save(output_dir, state_dict)
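As a quick sanity check (again, not part of the diff), the length-grouped sampler can be exercised on synthetic modality_lengths-style values, where positive entries stand for image-bearing samples and negative entries for text-only ones. The import path follows the file added above and assumes the llava package is importable.

# Toy demonstration of LengthGroupedSampler -- illustrative only.
import torch
from llava.train.llava_trainer import LengthGroupedSampler

# Positive lengths = samples with an image, negative = language-only samples,
# mirroring the modality_lengths convention in LazySupervisedDataset.
lengths = [120, -35, 80, -60, 95, -20, 40, -75]

sampler = LengthGroupedSampler(
    batch_size=2,
    world_size=2,
    lengths=lengths,
    generator=torch.Generator().manual_seed(0),
    group_by_modality=True,
)
print(list(sampler))  # a permutation of range(len(lengths)), grouped by modality and length

Grouping by modality keeps image-bearing and text-only samples in separate megabatches, so per-step compute and padding stay roughly uniform; this is presumably why the sampler is only selected when group_by_modality_length is set.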