myownskyW7 commited on
Commit
9a3a17c
1 Parent(s): afa37a8

Add fine-tuning code

Browse files
modeling_InternLM.py CHANGED
@@ -6,8 +6,6 @@ import torch
6
  import torch.utils.checkpoint
7
  import torch.utils.checkpoint
8
  from einops import rearrange
9
- #import rotary_emb
10
- #from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
11
  from torch import nn
12
  from torch.nn import CrossEntropyLoss
13
  from transformers.activations import ACT2FN
@@ -18,14 +16,17 @@ from transformers.utils import logging
18
  from .configuration_InternLM_XComposer import InternLMXComposerConfig
19
  from .modeling_utils import LoRALinear
20
 
 
 
 
 
 
21
  logger = logging.get_logger(__name__)
22
 
23
  _CONFIG_FOR_DOC = "InternLMXComposerConfig"
24
 
25
 
26
  def rotary_embed(x1, x2, cos, sin, conj):
27
- # print(x1.shape, x2.shape, cos.shape, sin.shape)
28
- #[5, 1, 32, 64] [1, 1, 64]
29
  x1, x2 = x1.float(), x2.float()
30
  if conj:
31
  x1, x2 = x1 * cos + x2 * sin, x1 * sin + x2 * cos
@@ -35,7 +36,6 @@ def rotary_embed(x1, x2, cos, sin, conj):
35
 
36
 
37
  class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
38
-
39
  @staticmethod
40
  def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
41
  """
@@ -55,18 +55,26 @@ class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
55
  assert seqlen <= rotary_seqlen
56
  cos_k = cos if cos_k is None else cos_k
57
  sin_k = sin if sin_k is None else sin_k
58
- assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
 
59
  q_ro = qkv[:, :, 0, :, :rotary_dim]
60
- q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
 
 
61
  # rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
62
  # rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
63
- q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), rearrange(sin[:seqlen], 's d -> s 1 d'), False)
 
64
  qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
65
  k_ro = qkv[:, :, 1, :, :rotary_dim]
66
- k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
 
 
67
  # rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
68
  # rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
69
- k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
 
 
70
  qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
71
  ctx.save_for_backward(cos, sin, cos_k, sin_k)
72
  ctx.interleaved = interleaved
@@ -79,18 +87,69 @@ class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
79
  rotary_dim = cos.shape[-1]
80
  rotary_dim *= 2
81
  dq_ro = dqkv[:, :, 0, :, :rotary_dim]
82
- dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
83
- else (dq_ro[..., ::2], dq_ro[..., 1::2]))
84
- rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
85
- rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
 
 
86
  dk_ro = dqkv[:, :, 1, :, :rotary_dim]
87
- dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
88
- else (dk_ro[..., ::2], dk_ro[..., 1::2]))
89
- rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
90
- rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
 
 
91
  return dqkv, None, None, None, None, None
92
 
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
95
  def __init__(self, dim: int, base=10000, scale_base=0, device=None):
96
  """ """
@@ -141,6 +200,23 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
141
  self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
142
  self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  def eval_forward(self, qkv, seqlen_offset=0):
145
  """
146
  seqlen_offset: can be used in generation where the qkv being passed in is only the last
@@ -161,6 +237,7 @@ class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
161
  )
162
 
163
 
 
164
  legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
165
 
166
 
@@ -1245,6 +1322,6 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
1245
  reordered_past = ()
1246
  for layer_past in past_key_values:
1247
  reordered_past += (tuple(
1248
- past_state.index_select(0, beam_idx)
1249
  for past_state in layer_past), )
1250
  return reordered_past
 
6
  import torch.utils.checkpoint
7
  import torch.utils.checkpoint
8
  from einops import rearrange
 
 
9
  from torch import nn
10
  from torch.nn import CrossEntropyLoss
11
  from transformers.activations import ACT2FN
 
16
  from .configuration_InternLM_XComposer import InternLMXComposerConfig
17
  from .modeling_utils import LoRALinear
18
 
19
+ try:
20
+ import rotary_emb
21
+ except Exception as e:
22
+ print('Please following docs/install.md to install rotary_emb if you want to do fine-tuning')
23
+
24
  logger = logging.get_logger(__name__)
25
 
26
  _CONFIG_FOR_DOC = "InternLMXComposerConfig"
27
 
28
 
29
  def rotary_embed(x1, x2, cos, sin, conj):
 
 
30
  x1, x2 = x1.float(), x2.float()
31
  if conj:
32
  x1, x2 = x1 * cos + x2 * sin, x1 * sin + x2 * cos
 
36
 
37
 
38
  class LegacyApplyRotaryEmbQKV_(torch.autograd.Function):
 
39
  @staticmethod
40
  def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
41
  """
 
55
  assert seqlen <= rotary_seqlen
56
  cos_k = cos if cos_k is None else cos_k
57
  sin_k = sin if sin_k is None else sin_k
58
+ assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
59
+ rotary_dim // 2)
60
  q_ro = qkv[:, :, 0, :, :rotary_dim]
61
+ q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2],
62
+ q_ro[...,
63
+ 1::2])
64
  # rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
65
  # rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
66
+ q1, q2 = rotary_embed(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
67
+ rearrange(sin[:seqlen], 's d -> s 1 d'), False)
68
  qkv[:, :, 0, :, :rotary_dim] = torch.cat([q1, q2], dim=-1)
69
  k_ro = qkv[:, :, 1, :, :rotary_dim]
70
+ k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2],
71
+ k_ro[...,
72
+ 1::2])
73
  # rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
74
  # rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
75
+ k1, k2 = rotary_embed(k1, k2, rearrange(cos_k[:seqlen],
76
+ 's d -> s 1 d'),
77
+ rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
78
  qkv[:, :, 1, :, :rotary_dim] = torch.cat([k1, k2], dim=-1)
79
  ctx.save_for_backward(cos, sin, cos_k, sin_k)
80
  ctx.interleaved = interleaved
 
87
  rotary_dim = cos.shape[-1]
88
  rotary_dim *= 2
89
  dq_ro = dqkv[:, :, 0, :, :rotary_dim]
90
+ dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved else
91
+ (dq_ro[..., ::2], dq_ro[..., 1::2]))
92
+ rotary_emb.apply_rotary(dq1, dq2,
93
+ rearrange(cos[:seqlen], 's d -> s 1 d'),
94
+ rearrange(sin[:seqlen], 's d -> s 1 d'), dq1,
95
+ dq2, True)
96
  dk_ro = dqkv[:, :, 1, :, :rotary_dim]
97
+ dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved else
98
+ (dk_ro[..., ::2], dk_ro[..., 1::2]))
99
+ rotary_emb.apply_rotary(dk1, dk2,
100
+ rearrange(cos_k[:seqlen], 's d -> s 1 d'),
101
+ rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1,
102
+ dk2, True)
103
  return dqkv, None, None, None, None, None
104
 
105
 
106
+ class ApplyRotaryEmbQKV_(torch.autograd.Function):
107
+ """
108
+ ApplyRotaryEmbQKV_
109
+ """
110
+ @staticmethod
111
+ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
112
+ """
113
+ qkv: (total, 3, nheads, headdim)
114
+ cos, sin: (seqlen, rotary_dim / 2)
115
+ cos_k, sin_k: (seqlen, rotary_dim / 2), optional
116
+ rotary_dim must be <= headdim
117
+ Apply rotary embedding *inplace* to the first rotary_dim of q and k.
118
+ """
119
+ _, three, _, headdim = qkv.shape
120
+ assert three == 3
121
+ rotary_seqlen, rotary_dim = cos.shape
122
+ rotary_dim *= 2
123
+ assert rotary_dim <= headdim
124
+ cos_k = cos if cos_k is None else cos_k
125
+ sin_k = sin if sin_k is None else sin_k
126
+ assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen,
127
+ rotary_dim // 2)
128
+ q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
129
+ rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"),
130
+ rearrange(sin, "s d -> s 1 d"), q1, q2, False)
131
+ k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
132
+ rotary_emb.apply_rotary(k1, k2, rearrange(cos_k, "s d -> s 1 d"),
133
+ rearrange(sin_k, "s d -> s 1 d"), k1, k2,
134
+ False)
135
+ ctx.save_for_backward(cos, sin, cos_k, sin_k)
136
+ return qkv
137
+
138
+ @staticmethod
139
+ def backward(ctx, dqkv):
140
+ cos, sin, cos_k, sin_k = ctx.saved_tensors
141
+ rotary_dim = cos.shape[-1]
142
+ rotary_dim *= 2
143
+ dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
144
+ rotary_emb.apply_rotary(dq1, dq2, rearrange(cos, "s d -> s 1 d"),
145
+ rearrange(sin, "s d -> s 1 d"), dq1, dq2, True)
146
+ dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
147
+ rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k, "s d -> s 1 d"),
148
+ rearrange(sin_k, "s d -> s 1 d"), dk1, dk2,
149
+ True)
150
+ return dqkv, None, None, None, None
151
+
152
+
153
  class ConvertedInternLMRotaryEmbedding(torch.nn.Module):
154
  def __init__(self, dim: int, base=10000, scale_base=0, device=None):
155
  """ """
 
200
  self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
201
  self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
202
 
203
+ def forward(self,
204
+ qkv: torch.Tensor,
205
+ indexes=0) -> Tuple[torch.Tensor, torch.Tensor]:
206
+ self._update_cos_sin_cache(qkv, indexes)
207
+ if self.scale is None:
208
+ return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes],
209
+ self._sin_cached[indexes]).to(
210
+ qkv.dtype)
211
+ else:
212
+ return apply_rotary_emb_qkv_(
213
+ qkv,
214
+ self._cos_cached[indexes],
215
+ self._sin_cached[indexes],
216
+ self._cos_k_cached[indexes],
217
+ self._sin_k_cached[indexes],
218
+ ).to(qkv.dtype)
219
+
220
  def eval_forward(self, qkv, seqlen_offset=0):
221
  """
222
  seqlen_offset: can be used in generation where the qkv being passed in is only the last
 
237
  )
238
 
239
 
240
+ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
241
  legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
242
 
243
 
 
1322
  reordered_past = ()
1323
  for layer_past in past_key_values:
1324
  reordered_past += (tuple(
1325
+ past_state.index_select(0, beam_idx.to(past_state.device))
1326
  for past_state in layer_past), )
1327
  return reordered_past
modeling_InternLM_XComposer.py CHANGED
@@ -46,12 +46,13 @@ conversation
46
  def __init__(self, config):
47
  super().__init__(config)
48
 
49
- print('Init VIT ... ', end='')
 
50
  self.visual_encoder = create_eva_vit_g()
51
  self.ln_vision = LayerNorm(self.visual_encoder.num_features)
52
- print('Done')
53
 
54
- print('Init Perceive Sampler ... ', end='')
55
  with all_logging_disabled():
56
  self.Qformer, self.query_tokens = self.init_qformer(
57
  config.num_query_token, self.visual_encoder.num_features)
@@ -61,9 +62,9 @@ conversation
61
  layer.output = None
62
  layer.intermediate = None
63
  self.Qformer.cls = None
64
- print('Done')
65
 
66
- print('Init InternLM ... ', end='')
67
  self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
68
  self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
69
  self.flag_image_start.requires_grad = False
@@ -81,7 +82,7 @@ conversation
81
  # speed up init llm
82
  with torch.device('meta'):
83
  self.internlm_model = InternLMForCausalLM._from_config(config)
84
- self.internlm_model.to_empty(device='cpu').to(torch.float16)
85
  self.internlm_model.to(config.device)
86
  for n, m in self.internlm_model.named_modules():
87
  if 'lora' in n:
@@ -89,7 +90,7 @@ conversation
89
 
90
  self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
91
  self.internlm_model.config.hidden_size)
92
- print('Done')
93
 
94
  self.vis_processor = transforms.Compose([
95
  transforms.Resize((224, 224),
@@ -111,6 +112,17 @@ conversation
111
  [StoppingCriteriaSub(stops=stop_words_ids)])
112
  self.gen_config['stopping_criteria'] = stopping_criteria
113
 
 
 
 
 
 
 
 
 
 
 
 
114
  def maybe_autocast(self, dtype=torch.float16):
115
  # if on cpu, don't use autocast
116
  # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
@@ -268,3 +280,133 @@ conversation
268
  if history is not None:
269
  prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
270
  return prompt_embeds
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def __init__(self, config):
47
  super().__init__(config)
48
 
49
+ self.max_length = config.max_length
50
+ rank0_print('Init VIT ... ', end='')
51
  self.visual_encoder = create_eva_vit_g()
52
  self.ln_vision = LayerNorm(self.visual_encoder.num_features)
53
+ rank0_print('Done')
54
 
55
+ rank0_print('Init Perceive Sampler ... ', end='')
56
  with all_logging_disabled():
57
  self.Qformer, self.query_tokens = self.init_qformer(
58
  config.num_query_token, self.visual_encoder.num_features)
 
62
  layer.output = None
63
  layer.intermediate = None
64
  self.Qformer.cls = None
65
+ rank0_print('Done')
66
 
67
+ rank0_print('Init InternLM ... ', end='')
68
  self.flag_image_start = nn.Parameter(torch.zeros([1, 1, 4096]))
69
  self.flag_image_end = nn.Parameter(torch.zeros([1, 1, 4096]))
70
  self.flag_image_start.requires_grad = False
 
82
  # speed up init llm
83
  with torch.device('meta'):
84
  self.internlm_model = InternLMForCausalLM._from_config(config)
85
+ self.internlm_model.to_empty(device=config.device).to(torch.float16)
86
  self.internlm_model.to(config.device)
87
  for n, m in self.internlm_model.named_modules():
88
  if 'lora' in n:
 
90
 
91
  self.internlm_proj = nn.Linear(self.Qformer.config.hidden_size,
92
  self.internlm_model.config.hidden_size)
93
+ rank0_print('Done')
94
 
95
  self.vis_processor = transforms.Compose([
96
  transforms.Resize((224, 224),
 
112
  [StoppingCriteriaSub(stops=stop_words_ids)])
113
  self.gen_config['stopping_criteria'] = stopping_criteria
114
 
115
+ self.supports_gradient_checkpointing = True
116
+
117
+ def get_input_embeddings(self):
118
+ return self.internlm_model.get_input_embeddings()
119
+
120
+ def _set_gradient_checkpointing(self, module, value=False):
121
+ if value:
122
+ self.internlm_model.apply(
123
+ partial(self.internlm_model._set_gradient_checkpointing,
124
+ value=True))
125
+
126
  def maybe_autocast(self, dtype=torch.float16):
127
  # if on cpu, don't use autocast
128
  # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
 
280
  if history is not None:
281
  prompt_embeds = torch.cat([*history, prompt_embeds], dim=1)
282
  return prompt_embeds
283
+
284
+ ######################
285
+ # code for training
286
+ ######################
287
+ def prompt_wrap(self, img_embeds, prompt):
288
+ batch_size = img_embeds.shape[0]
289
+ p_before, p_after = prompt.split('<ImageHere>')
290
+ p_before_tokens = self.tokenizer(p_before,
291
+ return_tensors="pt",
292
+ add_special_tokens=True).to(
293
+ img_embeds.device)
294
+
295
+ p_before_embeds = self.internlm_model.model.embed_tokens(
296
+ p_before_tokens.input_ids).expand(batch_size, -1, -1)
297
+ wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds], dim=1)
298
+
299
+ wrapped_atts_img = torch.ones(wrapped_img_embeds.size()[:-1],
300
+ dtype=torch.long).to(img_embeds.device)
301
+
302
+ wrapped_target = torch.ones(
303
+ batch_size, wrapped_img_embeds.shape[1], dtype=torch.long).to(
304
+ img_embeds.device) * -100
305
+
306
+ return wrapped_img_embeds, wrapped_atts_img, wrapped_target
307
+
308
+ def align_text(self, samples, has_img=False): ### add eos and eoa
309
+ text_new = []
310
+ if has_img: ### remove the first user to wrap image features
311
+ text = [
312
+ t.replace("<image>", "").split("<|User|>:", 1)[-1].lstrip()
313
+ for t in samples["text_input"]
314
+ ]
315
+ else:
316
+ text = [t for t in samples["text_input"]]
317
+
318
+ text = [t + self.eoa + ' </s>' for t in text]
319
+ for i in range(len(text)):
320
+ temp = text[i]
321
+ temp = temp.replace('<|Bot|>', self.eoh + ' <|Bot|>')
322
+ temp = temp.replace(' <|User|>', self.eoa + ' <|User|>')
323
+ if temp.find(self.eoh) > temp.find(self.eoa):
324
+ temp = temp.replace(self.eoa, '', 1)
325
+ text_new.append(temp)
326
+ return text_new
327
+
328
+ def text2emb(self, text):
329
+ to_regress_tokens = self.tokenizer(text,
330
+ return_tensors="pt",
331
+ padding="longest",
332
+ truncation=True,
333
+ max_length=self.max_length,
334
+ add_special_tokens=False).to(
335
+ self.device)
336
+
337
+ targets = self.mask_human_targets(to_regress_tokens.input_ids)
338
+ targets = targets.to(self.device)
339
+
340
+ return to_regress_tokens, targets
341
+
342
+ def mask_human_targets(self, input_ids, pure=False):
343
+ target_batch = []
344
+ for bs in range(input_ids.shape[0]):
345
+ cur_idx = 0
346
+ ids = input_ids[bs]
347
+ targets = copy.deepcopy(ids)
348
+ last_eoa = 0
349
+ last_eoh = 0
350
+ for i, temp_id in enumerate(ids):
351
+ if temp_id == 103027: #### end of human
352
+ targets[cur_idx:i + 6] = -100
353
+ cur_idx = i + 6
354
+ last_eoh = i
355
+ elif temp_id == 103028: ### end of assistant
356
+ cur_idx = i + 1
357
+ last_eoa = i
358
+ elif temp_id == 2: ### eos and following pad
359
+ targets[i + 1:] = -100 #### loss on eos, but not on pad
360
+ break
361
+ if temp_id != 2 and last_eoa > last_eoh: ### trunction, end at last question
362
+ targets[last_eoa +
363
+ 1:] = -100 #### mask all after the last answer
364
+
365
+ target_batch.append(targets.unsqueeze(0))
366
+
367
+ target_batch = torch.cat(target_batch, dim=0)
368
+ return target_batch
369
+
370
+ def forward(self,
371
+ input_ids=None,
372
+ attention_mask=None,
373
+ inputs_embeds=None,
374
+ labels=None,
375
+ output_attentions=None,
376
+ output_hidden_states=None,
377
+ return_dict=None,
378
+ **kwargs):
379
+
380
+ samples = kwargs.get('samples')
381
+ has_img = 'images' in samples.keys()
382
+
383
+ ### encode text
384
+ text = self.align_text(samples, has_img=has_img)
385
+ to_regress_tokens, targets = self.text2emb(text)
386
+
387
+ to_regress_embeds = self.internlm_model.model.embed_tokens(
388
+ to_regress_tokens.input_ids)
389
+ attention_mask = to_regress_tokens.attention_mask
390
+
391
+ if has_img:
392
+ header = samples["text_input"][0].split(' <|User|>:')[0]
393
+ prompt = header + ' <|User|>:<ImageHere>'
394
+
395
+ ### encode image
396
+ image = samples["image"]
397
+ img_embeds = self.encode_img(image)
398
+ img_embeds, atts_img, wrapped_target = self.prompt_wrap(
399
+ img_embeds, prompt)
400
+ ### combine text and image
401
+ to_regress_embeds = torch.cat([img_embeds, to_regress_embeds],
402
+ dim=1)
403
+ attention_mask = torch.cat([atts_img, attention_mask], dim=1)
404
+ targets = torch.cat([wrapped_target, targets], dim=1)
405
+
406
+ outputs = self.internlm_model(
407
+ inputs_embeds=to_regress_embeds,
408
+ attention_mask=attention_mask,
409
+ return_dict=True,
410
+ labels=targets,
411
+ )
412
+ return outputs
modeling_utils.py CHANGED
@@ -2,12 +2,12 @@ import logging
2
  import math
3
  import os
4
  from contextlib import contextmanager
5
- from transformers import StoppingCriteria, StoppingCriteriaList
6
 
7
  import timm.models.hub as timm_hub
8
  import torch
9
  import torch.distributed as dist
10
  import torch.nn as nn
 
11
 
12
 
13
  def is_dist_avail_and_initialized():
@@ -28,12 +28,16 @@ def is_main_process():
28
  return get_rank() == 0
29
 
30
 
 
 
 
 
 
31
  def download_cached_file(url, check_hash=True, progress=False):
32
  """
33
  Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
34
  If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
35
  """
36
-
37
  def get_cached_file_path():
38
  # a hack to sync the file path across processes
39
  parts = torch.hub.urlparse(url)
@@ -76,18 +80,16 @@ def all_logging_disabled(highest_level=logging.CRITICAL):
76
 
77
 
78
  class LoRALinear(nn.Linear):
79
- def __init__(
80
- self,
81
- in_features: int,
82
- out_features: int,
83
- bias: bool = True,
84
- device=None,
85
- dtype=None,
86
- lora_r=8,
87
- lora_alpha=16,
88
- lora_dropout=0.05,
89
- **kwargs
90
- ) -> None:
91
  super().__init__(in_features, out_features, bias, device, dtype)
92
  self.lora_r = lora_r
93
  self.lora_alpha = lora_alpha
@@ -97,12 +99,16 @@ class LoRALinear(nn.Linear):
97
  self.lora_dropout = lambda x: x
98
  self.lora_scaling = self.lora_alpha / self.lora_r
99
 
100
- self.lora_A = nn.Linear(
101
- in_features, self.lora_r, bias=False, device=device, dtype=dtype
102
- )
103
- self.lora_B = nn.Linear(
104
- self.lora_r, out_features, bias=False, device=device, dtype=dtype
105
- )
 
 
 
 
106
 
107
  self.reset_parameters()
108
 
@@ -116,7 +122,8 @@ class LoRALinear(nn.Linear):
116
  orig_type = x.dtype
117
  res = super().forward(x)
118
  x = x.float()
119
- res += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.lora_scaling
 
120
  return res.to(orig_type)
121
 
122
 
@@ -127,7 +134,7 @@ class StoppingCriteriaSub(StoppingCriteria):
127
 
128
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
129
  for stop in self.stops:
130
- if torch.all((stop == input_ids[:, -len(stop) :])).item():
131
  return True
132
 
133
  return False
 
2
  import math
3
  import os
4
  from contextlib import contextmanager
 
5
 
6
  import timm.models.hub as timm_hub
7
  import torch
8
  import torch.distributed as dist
9
  import torch.nn as nn
10
+ from transformers import StoppingCriteria, StoppingCriteriaList
11
 
12
 
13
  def is_dist_avail_and_initialized():
 
28
  return get_rank() == 0
29
 
30
 
31
+ def rank0_print(msg, **kwargs):
32
+ if is_main_process():
33
+ print(msg, **kwargs)
34
+
35
+
36
  def download_cached_file(url, check_hash=True, progress=False):
37
  """
38
  Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
39
  If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
40
  """
 
41
  def get_cached_file_path():
42
  # a hack to sync the file path across processes
43
  parts = torch.hub.urlparse(url)
 
80
 
81
 
82
  class LoRALinear(nn.Linear):
83
+ def __init__(self,
84
+ in_features: int,
85
+ out_features: int,
86
+ bias: bool = True,
87
+ device=None,
88
+ dtype=None,
89
+ lora_r=8,
90
+ lora_alpha=16,
91
+ lora_dropout=0.05,
92
+ **kwargs) -> None:
 
 
93
  super().__init__(in_features, out_features, bias, device, dtype)
94
  self.lora_r = lora_r
95
  self.lora_alpha = lora_alpha
 
99
  self.lora_dropout = lambda x: x
100
  self.lora_scaling = self.lora_alpha / self.lora_r
101
 
102
+ self.lora_A = nn.Linear(in_features,
103
+ self.lora_r,
104
+ bias=False,
105
+ device=device,
106
+ dtype=dtype)
107
+ self.lora_B = nn.Linear(self.lora_r,
108
+ out_features,
109
+ bias=False,
110
+ device=device,
111
+ dtype=dtype)
112
 
113
  self.reset_parameters()
114
 
 
122
  orig_type = x.dtype
123
  res = super().forward(x)
124
  x = x.float()
125
+ res += self.lora_B(self.lora_A(
126
+ self.lora_dropout(x))) * self.lora_scaling
127
  return res.to(orig_type)
128
 
129
 
 
134
 
135
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
136
  for stop in self.stops:
137
+ if torch.all((stop == input_ids[:, -len(stop):])).item():
138
  return True
139
 
140
  return False