vilarin commited on
Commit
dac7d4a
1 Parent(s): c5b3aef

Delete model

Browse files
Files changed (1) hide show
  1. model/modeling_360vl.py +0 -813
model/modeling_360vl.py DELETED
@@ -1,813 +0,0 @@
1
- from typing import List, Optional, Tuple, Union
2
-
3
- import torch
4
- import torch.nn as nn
5
-
6
- from torch.nn import CrossEntropyLoss
7
-
8
- from transformers import AutoConfig, AutoModelForCausalLM, \
9
- LlamaConfig, LlamaModel, LlamaForCausalLM
10
-
11
- from transformers.modeling_outputs import CausalLMOutputWithPast
12
-
13
- from PIL import Image
14
-
15
- from abc import ABC, abstractmethod
16
- import os
17
-
18
- import math
19
- from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
20
- from functools import partial
21
- from transformers.configuration_utils import PretrainedConfig
22
-
23
- from timm.models.layers import LayerNorm, LayerNorm2d
24
- from timm.models.regnet import RegStage
25
- from torch.nn import functional as F
26
- import math
27
- from einops import rearrange
28
-
29
-
30
-
31
- CONTROLLER_HEART_BEAT_EXPIRATION = 30
32
- WORKER_HEART_BEAT_INTERVAL = 15
33
-
34
- #LOGDIR = "."
35
- black_domains = ["device-api.zero", "checkip.amazonaws.com"]
36
-
37
-
38
- # Model Constants
39
- IGNORE_INDEX = -100
40
- IMAGE_TOKEN_INDEX = -200
41
- DEFAULT_IMAGE_TOKEN = "<image>"
42
- DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
43
- DEFAULT_IM_START_TOKEN = "<im_start>"
44
- DEFAULT_IM_END_TOKEN = "<im_end>"
45
-
46
-
47
-
48
-
49
-
50
- class CLIPVisionTower(nn.Module):
51
- def __init__(self, vision_tower, args, delay_load=False):
52
- super().__init__()
53
-
54
- self.is_loaded = False
55
-
56
- self.vision_tower_name = vision_tower
57
- self.select_layer = args.mm_vision_select_layer
58
- self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
59
-
60
- if not delay_load:
61
- self.load_model()
62
- else:
63
- self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
64
-
65
- def load_model(self):
66
- self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
67
- self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
68
- self.vision_tower.requires_grad_(False)
69
-
70
- self.is_loaded = True
71
-
72
- def feature_select(self, image_forward_outs):
73
- image_features = image_forward_outs.hidden_states[self.select_layer]
74
- if self.select_feature == 'patch':
75
- image_features = image_features[:, 1:]
76
- elif self.select_feature == 'cls_patch':
77
- image_features = image_features
78
- else:
79
- raise ValueError(f'Unexpected select feature: {self.select_feature}')
80
- return image_features
81
-
82
- @torch.no_grad()
83
- def forward(self, images):
84
- if type(images) is list:
85
- image_features = []
86
- for image in images:
87
- image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
88
- image_feature = self.feature_select(image_forward_out).to(image.dtype)
89
- image_features.append(image_feature)
90
- else:
91
- image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
92
- image_features = self.feature_select(image_forward_outs).to(images.dtype)
93
-
94
- return image_features
95
-
96
- @property
97
- def dummy_feature(self):
98
- return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
99
-
100
- @property
101
- def dtype(self):
102
- return self.vision_tower.dtype
103
-
104
- @property
105
- def device(self):
106
- return self.vision_tower.device
107
-
108
- @property
109
- def config(self):
110
- if self.is_loaded:
111
- return self.vision_tower.config
112
- else:
113
- return self.cfg_only
114
-
115
- @property
116
- def hidden_size(self):
117
- return self.config.hidden_size
118
-
119
- @property
120
- def num_patches(self):
121
- return (self.config.image_size // self.config.patch_size) ** 2
122
-
123
-
124
- def build_vision_tower(vision_tower_cfg, **kwargs):
125
- vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
126
- is_absolute_path_exists = os.path.exists(vision_tower)
127
-
128
- if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
129
- return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
130
-
131
- raise ValueError(f'Unknown vision tower: {vision_tower}')
132
-
133
-
134
-
135
-
136
-
137
- class HoneybeeVisualProjectorConfig(PretrainedConfig):
138
- model_type = "mllm_visual_projector"
139
-
140
- def __init__(
141
- self,
142
- projector_type: str = "resampler",
143
- hidden_size: int = 1024, #
144
- num_hidden_layers: int = 6, #
145
- num_attention_heads: int = 16, #
146
- intermediate_size: int = 4096, #
147
- attention_probs_dropout_prob: float = 0.1, #
148
- initializer_range: float = 0.02,
149
- layer_norm_eps: float = 1e-6, #
150
- encoder_hidden_size: int = 1024, # This will be overwritten by vision_model's hidden_size
151
- pos_emb=False,
152
- feature_layer_index=-1, # vision feature layer index; -1: last layer
153
- num_eos_tokens=1,
154
- use_cls=True,
155
- prenorm=False,
156
- **kwargs,
157
- ):
158
- super().__init__(**kwargs)
159
- self.projector_type = projector_type
160
- self.hidden_size = hidden_size
161
- self.num_hidden_layers = num_hidden_layers
162
- self.num_attention_heads = num_attention_heads
163
- self.intermediate_size = intermediate_size
164
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
165
- self.initializer_range = initializer_range
166
- self.layer_norm_eps = layer_norm_eps
167
- self.encoder_hidden_size = encoder_hidden_size
168
-
169
- self.pos_emb = pos_emb
170
- self.feature_layer_index = feature_layer_index
171
- self.num_eos_tokens = num_eos_tokens
172
- self.use_cls = use_cls
173
- self.prenorm = prenorm
174
-
175
- @classmethod
176
- def from_pretrained(
177
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
178
- ) -> "PretrainedConfig":
179
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
180
-
181
- # get the visual_projector config dict if we are loading from HoneybeeConfig
182
- if config_dict.get("model_type") == "QH_360VL":
183
- config_dict = config_dict["visual_projector_config"]
184
-
185
- '''
186
- if (
187
- "model_type" in config_dict
188
- and hasattr(cls, "model_type")
189
- and config_dict["model_type"] != cls.model_type
190
- ):
191
- logger.warning(
192
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
193
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
194
- )
195
- '''
196
-
197
- return cls.from_dict(config_dict, **kwargs)
198
-
199
- def build_pos_embeds(
200
- config: HoneybeeVisualProjectorConfig, num_input_tokens: int, vision_hidden_size: int
201
- ):
202
- # pos emb
203
- # true
204
- if config.pos_emb:
205
- pos_emb = torch.nn.Parameter(torch.zeros(1, num_input_tokens, vision_hidden_size))
206
- nn.init.trunc_normal_(pos_emb, mean=0.0, std=0.02)
207
- else:
208
- pos_emb = None
209
-
210
- return pos_emb
211
-
212
-
213
- def build_eos_tokens(config: HoneybeeVisualProjectorConfig, output_hidden_size: int):
214
- # think tokens
215
- num_eos_tokens = config.num_eos_tokens
216
- # 0
217
- if num_eos_tokens:
218
- eos_tokens = torch.nn.Parameter(torch.randn(1, num_eos_tokens, output_hidden_size))
219
- nn.init.trunc_normal_(eos_tokens, mean=0.0, std=config.initializer_range)
220
- else:
221
- eos_tokens = None
222
-
223
- return eos_tokens
224
-
225
-
226
- def build_prenorm(config: HoneybeeVisualProjectorConfig):
227
- # false
228
- if config.prenorm:
229
- prenorm = LayerNorm(config.encoder_hidden_size)
230
- else:
231
- prenorm = None
232
- return prenorm
233
-
234
-
235
- def build_mlp(depth, hidden_size, output_hidden_size):
236
- layers = [nn.Linear(hidden_size, output_hidden_size)]
237
- for _ in range(1, depth):
238
- layers.append(nn.SiLU())
239
- layers.append(nn.Linear(output_hidden_size, output_hidden_size))
240
- return nn.Sequential(*layers)
241
-
242
- def get_abs_pos(abs_pos, tgt_size):
243
- # abs_pos: L, C
244
- # tgt_size: M
245
- # return: M, C
246
- # 16,24
247
- src_size = int(math.sqrt(abs_pos.size(1)))
248
- # 32,48
249
- tgt_size = int(math.sqrt(tgt_size))
250
- dtype = abs_pos.dtype
251
-
252
- if src_size != tgt_size:
253
- return F.interpolate(
254
- abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
255
- size=(tgt_size, tgt_size),
256
- mode="bicubic",
257
- align_corners=False,
258
- ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
259
- else:
260
- return abs_pos
261
-
262
-
263
- class Projector(nn.Module):
264
- """Base projector class"""
265
-
266
- def __init__(
267
- self,
268
- config: HoneybeeVisualProjectorConfig,
269
- num_input_tokens: int,
270
- output_hidden_size: int,
271
- ):
272
- super().__init__()
273
- self.config = config
274
- self.num_input_tokens = num_input_tokens
275
- self.output_hidden_size = output_hidden_size
276
-
277
- # think tokens
278
- self.eos_tokens = build_eos_tokens(config, output_hidden_size)
279
-
280
- # pos emb
281
- self.pos_emb = build_pos_embeds(config, num_input_tokens, config.encoder_hidden_size)
282
-
283
- self.prenorm = build_prenorm(config)
284
-
285
- self.build_net()
286
-
287
- def build_net(self):
288
- raise NotImplementedError()
289
-
290
- def _forward(self, x):
291
- raise NotImplementedError()
292
-
293
- def forward(self, x: torch.Tensor) -> torch.Tensor:
294
- """
295
- Args:
296
- x: (B, L, encoder_hidden_size) tensor from the visual backbone (CLIP visual encoder), including cls token.
297
- """
298
- if self.prenorm is not None:
299
- x = self.prenorm(x)
300
-
301
- if self.pos_emb is not None:
302
- # self.pos_emb = self.pos_emb[:,1:]
303
- pos_emb = get_abs_pos(self.pos_emb[:,1:], x.size(1))
304
- pos_emb = pos_emb.to(device=x.device)
305
- x += pos_emb
306
-
307
- x = self._forward(x) # (B, L, output_hidden_size)
308
-
309
- B = x.size(0)
310
- if self.eos_tokens is not None:
311
- x = torch.cat([x, self.eos_tokens.expand(B, -1, -1)], dim=1)
312
- return x
313
-
314
-
315
- class ConvProjector(Projector):
316
- def _forward(self, x):
317
- # x: [B, L, dim]
318
- # x = x[:, 1:] # drop cls token and 2d forward
319
-
320
- hw = int(x.size(1) ** 0.5)
321
- x = rearrange(x, "b (h w) d -> b d h w", h=hw, w=hw)
322
- x = self.net(x)
323
- x = rearrange(x, "b d h w -> b (h w) d")
324
- x = self.readout(x)
325
-
326
- return x
327
-
328
-
329
- class CAbstractor(ConvProjector):
330
- """C-Abstractor"""
331
- def build_net(self):
332
- encoder_hidden_size = self.config.encoder_hidden_size
333
- hidden_size = self.config.hidden_size
334
- output_hidden_size = self.output_hidden_size
335
- depth = self.config.depth
336
- mlp_depth = self.config.mlp_depth
337
-
338
- n_queries = self.config.num_queries
339
- assert (n_queries ** 0.5).is_integer(), "n_queries must be square number"
340
- hw = int(n_queries ** 0.5)
341
-
342
- # RegBlock = ResBlock + SE
343
- RegBlock = partial(
344
- RegStage,
345
- stride=1,
346
- dilation=1,
347
- act_layer=nn.SiLU,
348
- norm_layer=LayerNorm2d,
349
- )
350
-
351
- s1 = RegBlock(
352
- depth,
353
- encoder_hidden_size,
354
- hidden_size,
355
- )
356
- sampler = nn.AdaptiveAvgPool2d((hw, hw))
357
- s2 = RegBlock(
358
- depth,
359
- hidden_size,
360
- hidden_size,
361
- )
362
-
363
- self.net = nn.Sequential(s1, sampler, s2)
364
- self.readout = build_mlp(mlp_depth, hidden_size, output_hidden_size)
365
-
366
- class IdentityMap(nn.Module):
367
- def __init__(self):
368
- super().__init__()
369
-
370
- def forward(self, x, *args, **kwargs):
371
- return x
372
-
373
- @property
374
- def config(self):
375
- return {"mm_projector_type": 'identity'}
376
-
377
-
378
- class SimpleResBlock(nn.Module):
379
- def __init__(self, channels):
380
- super().__init__()
381
- self.pre_norm = nn.LayerNorm(channels)
382
-
383
- self.proj = nn.Sequential(
384
- nn.Linear(channels, channels),
385
- nn.GELU(),
386
- nn.Linear(channels, channels)
387
- )
388
- def forward(self, x):
389
- x = self.pre_norm(x)
390
- return x + self.proj(x)
391
-
392
-
393
- def build_honeybee_projector(config, projector_type, num_tokens,lm_hidden_size):
394
- """Build projector (abstractor) and query_tokens (optionally for resampler)"""
395
- proj_config = config
396
- proj_type = projector_type
397
- num_tokens = num_tokens
398
- output_hidden_size = lm_hidden_size # LM hidden size
399
-
400
- abstractor = {
401
- "c-abs": CAbstractor,
402
- }[
403
- proj_type
404
- ](proj_config, num_tokens, output_hidden_size)
405
- return abstractor
406
-
407
-
408
- def build_vision_projector(config, delay_load=False, **kwargs):
409
- projector_type = getattr(config, 'mm_projector_type', 'linear')
410
-
411
- if projector_type == 'linear':
412
- return nn.Linear(config.mm_hidden_size, config.hidden_size)
413
-
414
- if projector_type == 'c-abs':
415
-
416
- local_config_path = config.mm_projector_config
417
- honeybee_config = HoneybeeVisualProjectorConfig.from_pretrained(local_config_path)
418
-
419
- num_tokens = config.mm_num_tokens
420
-
421
- lm_hidden_size = config.hidden_size
422
-
423
- abstractor = build_honeybee_projector(honeybee_config,projector_type,num_tokens,lm_hidden_size)
424
- return abstractor
425
-
426
- mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
427
- if mlp_gelu_match:
428
- mlp_depth = int(mlp_gelu_match.group(1))
429
- modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
430
- for _ in range(1, mlp_depth):
431
- modules.append(nn.GELU())
432
- modules.append(nn.Linear(config.hidden_size, config.hidden_size))
433
- return nn.Sequential(*modules)
434
-
435
- if projector_type == 'identity':
436
- return IdentityMap()
437
-
438
- raise ValueError(f'Unknown projector type: {projector_type}')
439
-
440
-
441
-
442
-
443
- class QH360_VL_MetaModel:
444
-
445
- def __init__(self, config):
446
- super(QH360_VL_MetaModel, self).__init__(config)
447
- if hasattr(config, "mm_vision_tower"):
448
- self.vision_tower = build_vision_tower(config, delay_load=True)
449
- self.mm_projector_ctt = build_vision_projector(config)
450
- self.mm_projector_ori = build_vision_projector(config)
451
-
452
-
453
-
454
- def get_vision_tower(self):
455
- vision_tower = getattr(self, 'vision_tower', None)
456
- if type(vision_tower) is list:
457
- vision_tower = vision_tower[0]
458
- return vision_tower
459
-
460
-
461
- class QH360_VL_MetaForCausalLM(ABC):
462
-
463
- @abstractmethod
464
- def get_model(self):
465
- pass
466
-
467
- def get_vision_tower(self):
468
- return self.get_model().get_vision_tower()
469
-
470
- def encode_images(self, images):
471
- image_features = self.get_model().get_vision_tower()(images)
472
- image_features = self.get_model().mm_projector(image_features)
473
- return image_features
474
-
475
- def encode_images_noprojector(self, images):
476
- image_features = self.get_model().get_vision_tower()(images)
477
- image_features = image_features.detach()
478
- return image_features
479
-
480
- def prepare_inputs_labels_for_multimodal(
481
- self, input_ids, attention_mask, past_key_values, labels, images
482
- ):
483
- vision_tower = self.get_vision_tower()
484
- if vision_tower is None or images is None or input_ids.shape[1] == 1:
485
- if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
486
- attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
487
- return input_ids, attention_mask, past_key_values, None, labels
488
-
489
- if type(images) is list or images.ndim == 5:
490
- image_features = []
491
- for image in images:
492
- if image.ndim == 3:
493
- image_features.append(self.encode_images(image.unsqueeze(0)).squeeze(0))
494
- elif image.ndim == 4:
495
- #NOTE cc-plan
496
- temp_feats = self.encode_images_noprojector(image)
497
- src_size = int(math.sqrt(temp_feats.shape[1]))
498
- temp_feats = temp_feats.reshape(temp_feats.shape[0]//5,5,-1, temp_feats.shape[-1])
499
- x1 = temp_feats[:,4,:,:]
500
- x = temp_feats[:,:4,:,:]
501
- x = x.reshape(x.shape[0], -1, src_size, src_size, x.shape[-1])
502
- x = x.transpose(1,2).reshape(x.shape[0], src_size,2,2, src_size, x.shape[-1])
503
- x = x.transpose(1,2).reshape(x.shape[0], -1, x.shape[-1])
504
- x1 = self.get_model().mm_projector_ori(x1).squeeze(0)
505
- x = self.get_model().mm_projector_ctt(x).squeeze(0)
506
- temp_feats_all = torch.cat([x,x1],dim=0)
507
- image_features.append(temp_feats_all)
508
- else:
509
- image_features = self.encode_images(images)
510
-
511
-
512
- new_input_embeds = []
513
- new_labels = [] if labels is not None else None
514
- cur_image_idx = 0
515
- for batch_idx, cur_input_ids in enumerate(input_ids):
516
- if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
517
- # multimodal LLM, but the current sample is not multimodal
518
- # FIXME: this is a hacky fix, for deepspeed zero3 to work
519
- half_len = cur_input_ids.shape[0] // 2
520
- cur_image_features = image_features[cur_image_idx]
521
- cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
522
- cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
523
- cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
524
- new_input_embeds.append(cur_input_embeds)
525
- if labels is not None:
526
- new_labels.append(labels[batch_idx])
527
- cur_image_idx += 1
528
- continue
529
- image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
530
- cur_new_input_embeds = []
531
- if labels is not None:
532
- cur_labels = labels[batch_idx]
533
- cur_new_labels = []
534
- assert cur_labels.shape == cur_input_ids.shape
535
- while image_token_indices.numel() > 0:
536
- cur_image_features = image_features[cur_image_idx]
537
- image_token_start = image_token_indices[0]
538
- if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
539
- cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach())
540
- cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start]))
541
- cur_new_input_embeds.append(cur_image_features)
542
- cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2]))
543
- if labels is not None:
544
- cur_new_labels.append(cur_labels[:image_token_start])
545
- cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
546
- cur_new_labels.append(cur_labels[image_token_start:image_token_start+1])
547
- cur_labels = cur_labels[image_token_start+2:]
548
- else:
549
- cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
550
- cur_new_input_embeds.append(cur_image_features)
551
- if labels is not None:
552
- cur_new_labels.append(cur_labels[:image_token_start])
553
- cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
554
- cur_labels = cur_labels[image_token_start+1:]
555
- cur_image_idx += 1
556
- if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
557
- cur_input_ids = cur_input_ids[image_token_start+2:]
558
- else:
559
- cur_input_ids = cur_input_ids[image_token_start+1:]
560
- image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
561
- if cur_input_ids.numel() > 0:
562
- if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
563
- cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach())
564
- else:
565
- cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
566
- if labels is not None:
567
- cur_new_labels.append(cur_labels)
568
- cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
569
- cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
570
- new_input_embeds.append(cur_new_input_embeds)
571
- if labels is not None:
572
- cur_new_labels = torch.cat(cur_new_labels, dim=0)
573
- new_labels.append(cur_new_labels)
574
-
575
- if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
576
- max_len = max(x.shape[0] for x in new_input_embeds)
577
-
578
- new_input_embeds_align = []
579
- for cur_new_embed in new_input_embeds:
580
- cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
581
- new_input_embeds_align.append(cur_new_embed)
582
- new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
583
-
584
- if labels is not None:
585
- new_labels_align = []
586
- _new_labels = new_labels
587
- for cur_new_label in new_labels:
588
- cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
589
- new_labels_align.append(cur_new_label)
590
- new_labels = torch.stack(new_labels_align, dim=0)
591
-
592
- if attention_mask is not None:
593
- new_attention_mask = []
594
- for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
595
- new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
596
- new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
597
- cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
598
- new_attention_mask.append(cur_new_attention_mask)
599
- attention_mask = torch.stack(new_attention_mask, dim=0)
600
- assert attention_mask.shape == new_labels.shape
601
- else:
602
- new_input_embeds = torch.stack(new_input_embeds, dim=0)
603
- if labels is not None:
604
- new_labels = torch.stack(new_labels, dim=0)
605
-
606
- if attention_mask is not None:
607
- new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
608
- attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
609
- assert attention_mask.shape == new_input_embeds.shape[:2]
610
-
611
- return None, attention_mask, past_key_values, new_input_embeds, new_labels
612
-
613
-
614
-
615
- class QH360_VLConfig(LlamaConfig):
616
- model_type = "QH_360VL"
617
-
618
-
619
- class QH360_VL_LlamaModel(QH360_VL_MetaModel, LlamaModel):
620
- config_class = QH360_VLConfig
621
-
622
- def __init__(self, config: LlamaConfig):
623
- super(QH360_VL_LlamaModel, self).__init__(config)
624
-
625
-
626
- class QH360_VL_LlamaForCausalLM(LlamaForCausalLM, QH360_VL_MetaForCausalLM):
627
- config_class = QH360_VLConfig
628
-
629
- def __init__(self, config):
630
- super(LlamaForCausalLM, self).__init__(config)
631
- config._attn_implementation == "flash_attention_2"
632
- self.model = QH360_VL_LlamaModel(config)
633
-
634
- self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
635
-
636
- # Initialize weights and apply final processing
637
- self.post_init()
638
-
639
- def get_model(self):
640
- return self.model
641
-
642
- def forward(
643
- self,
644
- input_ids: torch.LongTensor = None,
645
- attention_mask: Optional[torch.Tensor] = None,
646
- past_key_values: Optional[List[torch.FloatTensor]] = None,
647
- inputs_embeds: Optional[torch.FloatTensor] = None,
648
- labels: Optional[torch.LongTensor] = None,
649
- use_cache: Optional[bool] = None,
650
- output_attentions: Optional[bool] = None,
651
- output_hidden_states: Optional[bool] = None,
652
- images: Optional[torch.FloatTensor] = None,
653
- return_dict: Optional[bool] = None,
654
- ) -> Union[Tuple, CausalLMOutputWithPast]:
655
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
656
- output_hidden_states = (
657
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
658
- )
659
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
660
-
661
- input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
662
-
663
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
664
- outputs = self.model(
665
- input_ids=input_ids,
666
- attention_mask=attention_mask,
667
- past_key_values=past_key_values,
668
- inputs_embeds=inputs_embeds,
669
- use_cache=use_cache,
670
- output_attentions=output_attentions,
671
- output_hidden_states=output_hidden_states,
672
- return_dict=return_dict
673
- )
674
-
675
- hidden_states = outputs[0]
676
- logits = self.lm_head(hidden_states)
677
-
678
- loss = None
679
- if labels is not None:
680
- # Shift so that tokens < n predict n
681
- shift_logits = logits[..., :-1, :].contiguous()
682
- shift_labels = labels[..., 1:].contiguous()
683
- # Flatten the tokens
684
- loss_fct = CrossEntropyLoss()
685
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
686
- shift_labels = shift_labels.view(-1)
687
- # Enable model/pipeline parallelism
688
- shift_labels = shift_labels.to(shift_logits.device)
689
- loss = loss_fct(shift_logits, shift_labels)
690
-
691
- if not return_dict:
692
- output = (logits,) + outputs[1:]
693
- return (loss,) + output if loss is not None else output
694
-
695
- return CausalLMOutputWithPast(
696
- loss=loss,
697
- logits=logits,
698
- past_key_values=outputs.past_key_values,
699
- hidden_states=outputs.hidden_states,
700
- attentions=outputs.attentions,
701
- )
702
-
703
- def prepare_inputs_for_generation(
704
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
705
- ):
706
- if past_key_values:
707
- input_ids = input_ids[:, -1:]
708
-
709
- # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
710
- if inputs_embeds is not None and past_key_values is None:
711
- model_inputs = {"inputs_embeds": inputs_embeds}
712
- else:
713
- model_inputs = {"input_ids": input_ids}
714
-
715
- model_inputs.update(
716
- {
717
- "past_key_values": past_key_values,
718
- "use_cache": kwargs.get("use_cache"),
719
- "attention_mask": attention_mask,
720
- "images": kwargs.get("images", None),
721
- }
722
- )
723
- return model_inputs
724
-
725
- def build_conversation_input_ids(
726
- self,
727
- tokenizer: "PreTrainedTokenizer",
728
- query: str,
729
- image = None,
730
- image_processor=None,
731
- ):
732
-
733
- input_msg = [
734
- {
735
- "role": "system",
736
- "content": "You are a multilingual, helpful, respectful and honest assistant who can respond in the same language, depending on the language of the question. Try to be as helpful as possible while still being safe. Your answer should not contain anything that is false, unhealthy, harmful, immoral, racist, sexist, toxic, dangerous, or illegal, and if the question relates to such content, please decline to answer. Make sure your answer is socially fair and positive. If a question doesn't make any sense, or is inconsistent with the facts, explain why instead of answering the wrong answer. If you don't know the answer to a question, don't share false information."
737
- },
738
- {
739
- "role": "user",
740
- "content": "<|reserved_special_token_44|>"+ '\n' + query
741
- }
742
- ]
743
-
744
- input_ids = tokenizer.apply_chat_template(
745
- input_msg,
746
- add_generation_prompt=True,
747
- padding="longest",
748
- return_tensors="pt",
749
- )
750
- input_id_list = input_ids[0].tolist()
751
- input_id_list[input_id_list.index(128049)]=-200
752
- input_ids = torch.tensor(input_id_list, dtype=input_ids.dtype,device=input_ids.device)
753
- input_ids = input_ids.unsqueeze(0)
754
- image_tensor = self.process_images_slid_window(image,image_processor).unsqueeze(0)
755
-
756
- return {
757
- 'input_ids': input_ids,
758
- 'image': image_tensor,
759
- }
760
-
761
-
762
-
763
- def process_images_slid_window(self, image, image_processor, vit_is=336):
764
-
765
- def get_proper_imgsize(pil_img, vit_is):
766
- max_w_h = vit_is * 2
767
- new_pil_img = pil_img.resize((max_w_h, max_w_h))
768
- return new_pil_img
769
-
770
- def tensor_crop(tensor_array, left, upper, right, lower):
771
- # tensor_array: C * H * W
772
- return tensor_array[:, upper:lower, left:right]
773
-
774
- def image_slid_window(image, num_slid_window):
775
- # image: tensor, 3 * 336 * 336 or 3 * 672 * 672
776
- # image: tensor, 3 * 224 * 224 or 3 * 448 * 448
777
- if num_slid_window == 5:
778
- image_x2, image_x1 = image[0], image[1]
779
- vit_is = image_x1.shape[1]
780
- h, w = image_x2.shape[1],image_x2.shape[2]
781
- image0 = tensor_crop(image_x2, 0, 0, vit_is, vit_is)
782
- image1 = tensor_crop(image_x2, w-vit_is, 0, w, vit_is)
783
- image2 = tensor_crop(image_x2, 0, h-vit_is, vit_is, h)
784
- image3 = tensor_crop(image_x2, w-vit_is, h-vit_is, w, h)
785
- return torch.stack([image0, image1, image2, image3, image_x1])
786
- else:
787
- return image
788
-
789
- def expand2square(pil_img, background_color):
790
- width, height = pil_img.size
791
- if width == height:
792
- return pil_img
793
- elif width > height:
794
- result = Image.new(pil_img.mode, (width, width), background_color)
795
- result.paste(pil_img, (0, (width - height) // 2))
796
- return result
797
- else:
798
- result = Image.new(pil_img.mode, (height, height), background_color)
799
- result.paste(pil_img, ((height - width) // 2, 0))
800
- return result
801
-
802
- vit_is = vit_is # vit_input_size, for simplicity
803
-
804
- num_slid_window = 5
805
-
806
- image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
807
- image = get_proper_imgsize(image, vit_is)
808
- image_x2 = image_processor.preprocess(image, return_tensors='pt', do_resize=False, do_center_crop=False)['pixel_values'][0]
809
- image_x1 = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
810
- image = [image_x2, image_x1]
811
- image = image_slid_window(image, num_slid_window)
812
-
813
- return image