visheratin committed on
Commit 4830517, 1 Parent(s): 2d1df02

Upload folder using huggingface_hub

config.json CHANGED
@@ -1,24 +1,24 @@
 {
-  "architectures": [
-    "LlavaForConditionalGeneration"
-  ],
   "ignore_index": -100,
   "image_token_index": 50297,
-  "max_image_tokens": 100,
-  "model_type": "llava",
   "projector_hidden_act": "gelu",
-  "projector_tokens_num": 5,
   "text_config": {
-    "_name_or_path": "cognitivecomputations/dolphin-2_6-phi-2",
-    "activation_function": "gelu_new",
     "add_cross_attention": false,
     "architectures": [
       "PhiForCausalLM"
     ],
-    "attn_pdrop": 0.0,
     "auto_map": {
-      "AutoConfig": "cognitivecomputations/dolphin-2_6-phi-2--configuration_phi.PhiConfig",
-      "AutoModelForCausalLM": "cognitivecomputations/dolphin-2_6-phi-2--modeling_phi.PhiForCausalLM"
     },
     "bad_words_ids": null,
     "begin_suppress_tokens": null,
@@ -34,51 +34,51 @@
     "eos_token_id": null,
     "exponential_decay_length_penalty": null,
     "finetuning_task": null,
-    "flash_attn": false,
-    "flash_rotary": false,
     "forced_bos_token_id": null,
     "forced_eos_token_id": null,
-    "fused_dense": false,
     "id2label": {
       "0": "LABEL_0",
       "1": "LABEL_1"
     },
-    "img_processor": null,
     "initializer_range": 0.02,
     "is_decoder": false,
     "is_encoder_decoder": false,
     "label2id": {
       "LABEL_0": 0,
       "LABEL_1": 1
     },
-    "layer_norm_epsilon": 1e-05,
     "length_penalty": 1.0,
     "max_length": 20,
     "min_length": 0,
-    "model_type": "phi-msft",
-    "n_embd": 2560,
-    "n_head": 32,
-    "n_head_kv": null,
-    "n_inner": null,
-    "n_layer": 32,
-    "n_positions": 2048,
     "no_repeat_ngram_size": 0,
     "num_beam_groups": 1,
     "num_beams": 1,
     "num_return_sequences": 1,
     "output_attentions": false,
     "output_hidden_states": false,
     "output_scores": false,
     "pad_token_id": null,
     "prefix": null,
     "problem_type": null,
     "pruned_heads": {},
     "remove_invalid_values": false,
     "repetition_penalty": 1.0,
     "resid_pdrop": 0.1,
     "return_dict": true,
     "return_dict_in_generate": false,
-    "rotary_dim": 32,
     "sep_token_id": null,
     "suppress_tokens": null,
     "task_specific_params": null,
@@ -89,31 +89,25 @@
     "tokenizer_class": null,
     "top_k": 50,
     "top_p": 1.0,
-    "torch_dtype": "float16",
     "torchscript": false,
     "typical_p": 1.0,
     "use_bfloat16": false,
-    "use_cache": false,
     "vocab_size": 51200
   },
-  "preprocess_config": {
-    "mean": [
-      0.5,
-      0.5,
-      0.5
-    ],
-    "std": [
-      0.5,
-      0.5,
-      0.5
-    ],
-    "interpolation": "bicubic",
-    "resize_mode": "squash",
-    "size": 384
   },
-  "torch_dtype": "float16",
-  "transformers_version": "4.36.2",
   "vision_embed_dim": 1152,
-  "vision_tower_name": "ViT-SO400M-14-SigLIP-384",
   "vocab_size": 51200
-}

 {
+  "auto_map": {
+    "AutoConfig": "configuration_llava.LlavaConfig",
+    "AutoModel": "modeling_llava.LlavaForCausalLM",
+    "AutoModelForCausalLM": "modeling_llava.LlavaForCausalLM"
+  },
+  "model_type": "mc-llava",
   "ignore_index": -100,
   "image_token_index": 50297,
   "projector_hidden_act": "gelu",
+  "projector_tokens_num": 1,
   "text_config": {
+    "_name_or_path": "vince62s/phi-2-psy",
     "add_cross_attention": false,
     "architectures": [
       "PhiForCausalLM"
     ],
+    "attention_dropout": 0.0,
     "auto_map": {
+      "AutoConfig": "vince62s/phi-2-psy--configuration_phi.PhiConfig",
+      "AutoModelForCausalLM": "vince62s/phi-2-psy--modeling_phi.PhiForCausalLM"
     },
     "bad_words_ids": null,
     "begin_suppress_tokens": null,

     "eos_token_id": null,
     "exponential_decay_length_penalty": null,
     "finetuning_task": null,
     "forced_bos_token_id": null,
     "forced_eos_token_id": null,
+    "hidden_act": "gelu_new",
+    "hidden_size": 2560,
     "id2label": {
       "0": "LABEL_0",
       "1": "LABEL_1"
     },
     "initializer_range": 0.02,
+    "intermediate_size": 10240,
     "is_decoder": false,
     "is_encoder_decoder": false,
     "label2id": {
       "LABEL_0": 0,
       "LABEL_1": 1
     },
+    "layer_norm_eps": 1e-05,
     "length_penalty": 1.0,
     "max_length": 20,
+    "max_position_embeddings": 2048,
     "min_length": 0,
+    "model_type": "phi",
     "no_repeat_ngram_size": 0,
+    "num_attention_heads": 32,
     "num_beam_groups": 1,
     "num_beams": 1,
+    "num_hidden_layers": 32,
+    "num_key_value_heads": 32,
     "num_return_sequences": 1,
     "output_attentions": false,
     "output_hidden_states": false,
     "output_scores": false,
     "pad_token_id": null,
+    "partial_rotary_factor": 0.4,
     "prefix": null,
     "problem_type": null,
     "pruned_heads": {},
+    "qk_layernorm": false,
     "remove_invalid_values": false,
     "repetition_penalty": 1.0,
     "resid_pdrop": 0.1,
     "return_dict": true,
     "return_dict_in_generate": false,
+    "rope_scaling": null,
+    "rope_theta": 10000.0,
     "sep_token_id": null,
     "suppress_tokens": null,
     "task_specific_params": null,

     "tokenizer_class": null,
     "top_k": 50,
     "top_p": 1.0,
+    "torch_dtype": "bfloat16",
     "torchscript": false,
     "typical_p": 1.0,
     "use_bfloat16": false,
+    "use_cache": true,
     "vocab_size": 51200
   },
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.37.2",
+  "vision_config": {
+    "hidden_size": 1152,
+    "image_size": 384,
+    "intermediate_size": 4304,
+    "model_type": "siglip_vision_model",
+    "num_attention_heads": 16,
+    "num_hidden_layers": 27,
+    "patch_size": 14
+  },
   "vision_embed_dim": 1152,
+  "vision_tower_name": "google/siglip-so400m-patch14-384",
   "vocab_size": 51200
+}
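
The new auto_map entries and the "mc-llava" model_type are what route this repository through the Auto classes with remote code. A minimal loading sketch, assuming this repository's Hub id (written as a placeholder) and a transformers release that ships SiglipVisionConfig (4.37+):

import torch
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "<namespace>/<repo>"  # placeholder for this repository's Hub id

# trust_remote_code=True makes transformers import configuration_llava.py and
# modeling_llava.py from the repo, as wired up by the auto_map above.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
print(config.model_type)              # "mc-llava"
print(config.text_config.model_type)  # "phi"

model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,  # matches the torch_dtype stored above
)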
configuration_llava.py CHANGED
@@ -1,18 +1,107 @@
-# coding=utf-8
-
 from transformers.configuration_utils import PretrainedConfig
-from open_clip import get_model_config
-from configuration_phi import PhiConfig


 class LlavaConfig(PretrainedConfig):
-    model_type = "llava"
     is_composition = False

     def __init__(
         self,
         text_config=None,
-        vision_tower_name="ViT-SO400M-14-SigLIP-384",
         ignore_index=-100,
         image_token_index=50297,
         projector_hidden_act="gelu",
@@ -26,16 +115,17 @@ class LlavaConfig(PretrainedConfig):
         self.projector_tokens_num = projector_tokens_num
         self.vocab_size = vocab_size

-        self.vision_tower_name = vision_tower_name
-        vision_config = get_model_config(vision_tower_name)
-        self.vision_embed_dim = vision_config["embed_dim"]
-
-        self.vocab_size = self.vocab_size
-
         self.text_config = text_config
         if isinstance(self.text_config, dict):
-            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
             self.text_config = PhiConfig(**text_config)
             self.vocab_size = self.text_config.vocab_size

-        super().__init__(**kwargs)

 from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+from transformers import SiglipVisionConfig
+
+
+logger = logging.get_logger(__name__)
+
+
+class PhiConfig(PretrainedConfig):
+    model_type = "phi"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=51200,
+        hidden_size=2048,
+        intermediate_size=8192,
+        num_hidden_layers=24,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        resid_pdrop=0.0,
+        embd_pdrop=0.0,
+        attention_dropout=0.0,
+        hidden_act="gelu_new",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        partial_rotary_factor=0.5,
+        qk_layernorm=False,
+        bos_token_id=1,
+        eos_token_id=2,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.resid_pdrop = resid_pdrop
+        self.embd_pdrop = embd_pdrop
+        self.attention_dropout = attention_dropout
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.partial_rotary_factor = partial_rotary_factor
+        self.qk_layernorm = qk_layernorm
+        self._rope_scaling_validation()
+
+        super().__init__(
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if (
+            rope_scaling_factor is None
+            or not isinstance(rope_scaling_factor, float)
+            or rope_scaling_factor <= 1.0
+        ):
+            raise ValueError(
+                f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
+            )


 class LlavaConfig(PretrainedConfig):
+    model_type = "mc-llava"
     is_composition = False

     def __init__(
         self,
         text_config=None,
+        vision_config=None,
         ignore_index=-100,
         image_token_index=50297,
         projector_hidden_act="gelu",

         self.projector_tokens_num = projector_tokens_num
         self.vocab_size = vocab_size

         self.text_config = text_config
         if isinstance(self.text_config, dict):
+            text_config["model_type"] = (
+                text_config["model_type"] if "model_type" in text_config else "phi"
+            )
             self.text_config = PhiConfig(**text_config)
             self.vocab_size = self.text_config.vocab_size

+        self.vision_config = vision_config
+        if isinstance(self.vision_config, dict):
+            self.vision_config = SiglipVisionConfig(**vision_config)
+        self.vision_embed_dim = self.vision_config.hidden_size
+
+        super().__init__(**kwargs)
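
A small usage sketch of the rewritten config classes, assuming configuration_llava.py from this commit is importable from the working directory; projector_tokens_num and vocab_size are passed explicitly because their defaults fall in the part of the signature not shown in this hunk:

from configuration_llava import LlavaConfig

config = LlavaConfig(
    text_config={"model_type": "phi", "hidden_size": 2560, "num_hidden_layers": 32},
    vision_config={"hidden_size": 1152, "image_size": 384, "patch_size": 14},
    projector_tokens_num=1,
    vocab_size=51200,
)
assert config.text_config.model_type == "phi"  # dict is promoted to PhiConfig
assert config.vision_embed_dim == 1152         # derived from vision_config.hidden_size
assert config.vocab_size == 51200              # taken from text_config.vocab_size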
generation_config.json CHANGED
@@ -1,5 +1,4 @@
 {
   "_from_model_config": true,
-  "transformers_version": "4.36.2",
-  "use_cache": false
 }

 {
   "_from_model_config": true,
+  "transformers_version": "4.37.2"
 }
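
With "use_cache": false gone from generation_config.json (and use_cache now true in the text config above), generation falls back to the default KV-cache behaviour. A quick check, with the repo id again a placeholder:

from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("<namespace>/<repo>")  # placeholder repo id
print(gen_config.use_cache)  # True: no longer pinned to False by this file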
model-00001-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4a26ade091346f7adf46838555315d0ae89b3e704485be64f5cd6490f17f1f73
-size 4989958040

 version https://git-lfs.github.com/spec/v1
+oid sha256:0fa70d660b4f81b87cc3a93e89e8da38f6abf08dae6f51b2bef772f420799cec
+size 4969060728
model-00002-of-00002.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3bf488b78e5b0b9929d4dfa3afb8760759fb610ac8ae7835797f3dc5670272f0
-size 1520997992

 version https://git-lfs.github.com/spec/v1
+oid sha256:2e45b58c3bebb723eb7493f75356e75e090f0c62865e49a19e19ba73b80241a2
+size 1468562584
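
Both shard entries are Git LFS pointers, so only the oid and size change here. A hedged sketch for checking a locally downloaded shard against the new pointer values (the local path is an assumption):

import hashlib
import os

def sha256_of(path, chunk_size=1 << 20):
    # Stream the file so multi-gigabyte shards do not need to fit in memory.
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()

path = "model-00002-of-00002.safetensors"
assert os.path.getsize(path) == 1468562584
assert sha256_of(path) == "2e45b58c3bebb723eb7493f75356e75e090f0c62865e49a19e19ba73b80241a2"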
model.safetensors.index.json CHANGED
The diff for this file is too large to render. See raw diff
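
The index cannot be rendered inline, but it is plain JSON with a metadata block and a weight_map from parameter names to shard files, so it can be inspected locally once downloaded:

import json

with open("model.safetensors.index.json") as f:
    index = json.load(f)

print(index["metadata"]["total_size"])            # total checkpoint size in bytes
print(sorted(set(index["weight_map"].values())))  # the two shard files listed above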
 
modeling_llava.py CHANGED
@@ -1,17 +1,1304 @@
 # coding=utf-8
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
 from torch import nn

-from transformers import PreTrainedModel
-from transformers.modeling_outputs import ModelOutput

-from modeling_phi import PhiForCausalLM
-from configuration_llava import LlavaConfig
-from open_clip import create_model


 @dataclass
@@ -24,22 +1311,80 @@ class LlavaCausalLMOutputWithPast(ModelOutput):
     image_features: Optional[torch.FloatTensor] = None


 class LlavaMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
         super().__init__()

         self.linear_1 = nn.Linear(
             config.vision_embed_dim,
-            config.text_config.n_embd * config.projector_tokens_num,
             bias=True,
         )
         self.act = nn.GELU()
         self.linear_2 = nn.Linear(
-            config.text_config.n_embd * 5,
-            config.text_config.n_embd,
             bias=True,
         )
-        self.projector_tokens_num = config.projector_tokens_num

     def forward(self, image_features):
         hidden_states = self.linear_1(image_features)
@@ -71,11 +1416,10 @@ class LlavaPreTrainedModel(PreTrainedModel):
         return self.language_model._supports_sdpa


-class LlavaForConditionalGeneration(LlavaPreTrainedModel):
     def __init__(self, config: LlavaConfig):
         super().__init__(config)
-        clip_model = create_model(config.vision_tower_name)
-        self.vision_model = clip_model.visual

         self.multi_modal_projector = LlavaMultiModalProjector(config)
         self.vocab_size = config.vocab_size
@@ -261,7 +1605,6 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):

         logits = outputs[0]

-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return output
@@ -283,7 +1626,9 @@ class LlavaForConditionalGeneration(LlavaPreTrainedModel):
         image_features=None,
         **kwargs,
     ):
-        res = self.language_model.prepare_inputs_for_generation(input_ids, past_key_values, attention_mask, **kwargs)
         input_ids = res["input_ids"]
         past_key_values = res["past_key_values"]
         attention_mask = res["attention_mask"]
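
One visible consequence of this rewrite: the vision tower no longer comes from open_clip's create_model but from the transformers SigLIP implementation (the new file imports SiglipVisionModel, and config.json now carries a siglip_vision_model vision_config). A rough sketch of what that wiring enables, not this repo's exact code:

from transformers import SiglipVisionConfig, SiglipVisionModel

vision_config = SiglipVisionConfig(
    hidden_size=1152, image_size=384, intermediate_size=4304,
    num_attention_heads=16, num_hidden_layers=27, patch_size=14,
)
vision_model = SiglipVisionModel(vision_config)  # randomly initialised here; real weights come from the shards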
 
1
  # coding=utf-8
2
+ import math
3
  from dataclasses import dataclass
4
  from typing import List, Optional, Tuple, Union
5
 
6
  import torch
7
+ import torch.nn.functional as F
8
  import torch.utils.checkpoint
9
+ from configuration_llava import LlavaConfig
10
+ from configuration_phi import PhiConfig
11
  from torch import nn
12
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
13
+ from transformers import PreTrainedModel, SiglipVisionModel
14
+ from transformers.activations import ACT2FN
15
+ from transformers.cache_utils import Cache, DynamicCache
16
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ ModelOutput,
21
+ SequenceClassifierOutputWithPast,
22
+ TokenClassifierOutput,
23
+ )
24
+ from transformers.utils import (
25
+ is_flash_attn_2_available,
26
+ is_flash_attn_greater_or_equal_2_10,
27
+ logging,
28
+ )
29
 
30
+ try:
31
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
32
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
33
+ except Exception as exp:
34
+ print(exp)
35
 
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ # Copied from transformers.models.llama.modeling_llama._get_unpad_data
41
+ def _get_unpad_data(attention_mask):
42
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
43
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
44
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
45
+ cu_seqlens = F.pad(
46
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
47
+ )
48
+ return (
49
+ indices,
50
+ cu_seqlens,
51
+ max_seqlen_in_batch,
52
+ )
53
+
54
+
55
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
56
+ class PhiRotaryEmbedding(nn.Module):
57
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
58
+ super().__init__()
59
+
60
+ self.dim = dim
61
+ self.max_position_embeddings = max_position_embeddings
62
+ self.base = base
63
+ inv_freq = 1.0 / (
64
+ self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
65
+ )
66
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
67
+
68
+ # Build here to make `torch.jit.trace` work.
69
+ self._set_cos_sin_cache(
70
+ seq_len=max_position_embeddings,
71
+ device=self.inv_freq.device,
72
+ dtype=torch.get_default_dtype(),
73
+ )
74
+
75
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
76
+ self.max_seq_len_cached = seq_len
77
+ t = torch.arange(
78
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
79
+ )
80
+
81
+ freqs = torch.outer(t, self.inv_freq)
82
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
83
+ emb = torch.cat((freqs, freqs), dim=-1)
84
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
85
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
86
+
87
+ def forward(self, x, seq_len=None):
88
+ # x: [bs, num_attention_heads, seq_len, head_size]
89
+ if seq_len > self.max_seq_len_cached:
90
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
91
+
92
+ return (
93
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
94
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
95
+ )
96
+
97
+
98
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi
99
+ class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding):
100
+ """PhiRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
101
+
102
+ def __init__(
103
+ self,
104
+ dim,
105
+ max_position_embeddings=2048,
106
+ base=10000,
107
+ device=None,
108
+ scaling_factor=1.0,
109
+ ):
110
+ self.scaling_factor = scaling_factor
111
+ super().__init__(dim, max_position_embeddings, base, device)
112
+
113
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
114
+ self.max_seq_len_cached = seq_len
115
+ t = torch.arange(
116
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
117
+ )
118
+ t = t / self.scaling_factor
119
+
120
+ freqs = torch.outer(t, self.inv_freq)
121
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
122
+ emb = torch.cat((freqs, freqs), dim=-1)
123
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
124
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
125
+
126
+
127
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi
128
+ class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding):
129
+ """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
130
+
131
+ def __init__(
132
+ self,
133
+ dim,
134
+ max_position_embeddings=2048,
135
+ base=10000,
136
+ device=None,
137
+ scaling_factor=1.0,
138
+ ):
139
+ self.scaling_factor = scaling_factor
140
+ super().__init__(dim, max_position_embeddings, base, device)
141
+
142
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
143
+ self.max_seq_len_cached = seq_len
144
+
145
+ if seq_len > self.max_position_embeddings:
146
+ base = self.base * (
147
+ (self.scaling_factor * seq_len / self.max_position_embeddings)
148
+ - (self.scaling_factor - 1)
149
+ ) ** (self.dim / (self.dim - 2))
150
+ inv_freq = 1.0 / (
151
+ base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
152
+ )
153
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
154
+
155
+ t = torch.arange(
156
+ self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
157
+ )
158
+
159
+ freqs = torch.outer(t, self.inv_freq)
160
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
161
+ emb = torch.cat((freqs, freqs), dim=-1)
162
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
163
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
164
+
165
+
166
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
167
+ def rotate_half(x):
168
+ """Rotates half the hidden dims of the input."""
169
+ x1 = x[..., : x.shape[-1] // 2]
170
+ x2 = x[..., x.shape[-1] // 2 :]
171
+ return torch.cat((-x2, x1), dim=-1)
172
+
173
+
174
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
175
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
176
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
177
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
178
+ q_embed = (q * cos) + (rotate_half(q) * sin)
179
+ k_embed = (k * cos) + (rotate_half(k) * sin)
180
+ return q_embed, k_embed
181
+
182
+
183
+ # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi
184
+ class PhiMLP(nn.Module):
185
+ def __init__(self, config):
186
+ super().__init__()
187
+ self.config = config
188
+ self.activation_fn = ACT2FN[config.hidden_act]
189
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
190
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
191
+
192
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
193
+ hidden_states = self.fc1(hidden_states)
194
+ hidden_states = self.activation_fn(hidden_states)
195
+ hidden_states = self.fc2(hidden_states)
196
+ return hidden_states
197
+
198
+
199
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
200
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
201
+ """
202
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
203
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
204
+ """
205
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
206
+ if n_rep == 1:
207
+ return hidden_states
208
+ hidden_states = hidden_states[:, :, None, :, :].expand(
209
+ batch, num_key_value_heads, n_rep, slen, head_dim
210
+ )
211
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
212
+
213
+
214
+ class PhiAttention(nn.Module):
215
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
216
+
217
+ def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None):
218
+ super().__init__()
219
+ self.config = config
220
+ self.layer_idx = layer_idx
221
+ if layer_idx is None:
222
+ logger.warning_once(
223
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
224
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
225
+ "when creating this class."
226
+ )
227
+
228
+ self.attention_dropout = config.attention_dropout
229
+ self.hidden_size = config.hidden_size
230
+ self.num_heads = config.num_attention_heads
231
+ self.head_dim = self.hidden_size // self.num_heads
232
+ self.num_key_value_heads = config.num_key_value_heads
233
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
234
+ self.max_position_embeddings = config.max_position_embeddings
235
+ self.rope_theta = config.rope_theta
236
+ self.partial_rotary_factor = config.partial_rotary_factor
237
+ self.is_causal = True
238
+
239
+ if (self.head_dim * self.num_heads) != self.hidden_size:
240
+ raise ValueError(
241
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
242
+ f" and `num_heads`: {self.num_heads})."
243
+ )
244
+
245
+ self.q_proj = nn.Linear(
246
+ self.hidden_size, self.num_heads * self.head_dim, bias=True
247
+ )
248
+ self.k_proj = nn.Linear(
249
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
250
+ )
251
+ self.v_proj = nn.Linear(
252
+ self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True
253
+ )
254
+ self.dense = nn.Linear(
255
+ self.num_heads * self.head_dim, self.hidden_size, bias=True
256
+ )
257
+
258
+ self.qk_layernorm = config.qk_layernorm
259
+ if self.qk_layernorm:
260
+ self.q_layernorm = nn.LayerNorm(
261
+ config.hidden_size // self.num_heads,
262
+ eps=config.layer_norm_eps,
263
+ elementwise_affine=True,
264
+ )
265
+ self.k_layernorm = nn.LayerNorm(
266
+ config.hidden_size // self.num_heads,
267
+ eps=config.layer_norm_eps,
268
+ elementwise_affine=True,
269
+ )
270
+
271
+ self._init_rope()
272
+
273
+ def _init_rope(self):
274
+ if self.config.rope_scaling is None:
275
+ self.rotary_emb = PhiRotaryEmbedding(
276
+ int(self.partial_rotary_factor * self.head_dim),
277
+ max_position_embeddings=self.max_position_embeddings,
278
+ base=self.rope_theta,
279
+ )
280
+ else:
281
+ scaling_type = self.config.rope_scaling["type"]
282
+ scaling_factor = self.config.rope_scaling["factor"]
283
+ if scaling_type == "linear":
284
+ self.rotary_emb = PhiLinearScalingRotaryEmbedding(
285
+ int(self.partial_rotary_factor * self.head_dim),
286
+ max_position_embeddings=self.max_position_embeddings,
287
+ scaling_factor=scaling_factor,
288
+ base=self.rope_theta,
289
+ )
290
+ elif scaling_type == "dynamic":
291
+ self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding(
292
+ int(self.partial_rotary_factor * self.head_dim),
293
+ max_position_embeddings=self.max_position_embeddings,
294
+ scaling_factor=scaling_factor,
295
+ base=self.rope_theta,
296
+ )
297
+ else:
298
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
299
+
300
+ # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled
301
+ @torch.autocast("cpu", enabled=False)
302
+ @torch.autocast("cuda", enabled=False)
303
+ def forward(
304
+ self,
305
+ hidden_states: torch.Tensor,
306
+ attention_mask: Optional[torch.Tensor] = None,
307
+ position_ids: Optional[torch.LongTensor] = None,
308
+ past_key_value: Optional[Cache] = None,
309
+ output_attentions: bool = False,
310
+ use_cache: bool = False,
311
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
312
+ bsz, q_len, _ = hidden_states.size()
313
+
314
+ query_states = self.q_proj(hidden_states)
315
+ key_states = self.k_proj(hidden_states)
316
+ value_states = self.v_proj(hidden_states)
317
+
318
+ if self.qk_layernorm:
319
+ query_states = self.q_layernorm(query_states)
320
+ key_states = self.k_layernorm(key_states)
321
+
322
+ query_states = query_states.view(
323
+ bsz, q_len, self.num_heads, self.head_dim
324
+ ).transpose(1, 2)
325
+ key_states = key_states.view(
326
+ bsz, q_len, self.num_key_value_heads, self.head_dim
327
+ ).transpose(1, 2)
328
+ value_states = value_states.view(
329
+ bsz, q_len, self.num_key_value_heads, self.head_dim
330
+ ).transpose(1, 2)
331
+
332
+ kv_seq_len = key_states.shape[-2]
333
+ if past_key_value is not None:
334
+ if self.layer_idx is None:
335
+ raise ValueError(
336
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
337
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
338
+ "with a layer index."
339
+ )
340
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
341
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
342
+
343
+ # Partial rotary embedding
344
+ query_rot, query_pass = (
345
+ query_states[..., : self.rotary_emb.dim],
346
+ query_states[..., self.rotary_emb.dim :],
347
+ )
348
+ key_rot, key_pass = (
349
+ key_states[..., : self.rotary_emb.dim],
350
+ key_states[..., self.rotary_emb.dim :],
351
+ )
352
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
353
+ query_rot, key_rot = apply_rotary_pos_emb(
354
+ query_rot, key_rot, cos, sin, position_ids
355
+ )
356
+
357
+ # [batch_size, seq_length, num_heads, head_dim]
358
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
359
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
360
+
361
+ if past_key_value is not None:
362
+ cache_kwargs = {
363
+ "sin": sin,
364
+ "cos": cos,
365
+ "partial_rotation_size": self.rotary_emb.dim,
366
+ }
367
+ key_states, value_states = past_key_value.update(
368
+ key_states, value_states, self.layer_idx, cache_kwargs
369
+ )
370
+
371
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
372
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
373
+
374
+ # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
375
+ attn_weights = torch.matmul(
376
+ query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3)
377
+ ) / math.sqrt(self.head_dim)
378
+
379
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
380
+ raise ValueError(
381
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
382
+ f" {attn_weights.size()}"
383
+ )
384
+
385
+ if attention_mask is not None:
386
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
387
+ raise ValueError(
388
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
389
+ )
390
+ attn_weights = attn_weights + attention_mask
391
+
392
+ # upcast attention to fp32
393
+ attn_weights = nn.functional.softmax(
394
+ attn_weights, dim=-1, dtype=torch.float32
395
+ ).to(value_states.dtype)
396
+ attn_weights = nn.functional.dropout(
397
+ attn_weights, p=self.attention_dropout, training=self.training
398
+ )
399
+
400
+ attn_output = torch.matmul(attn_weights, value_states)
401
+
402
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
403
+ raise ValueError(
404
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
405
+ f" {attn_output.size()}"
406
+ )
407
+
408
+ attn_output = attn_output.transpose(1, 2).contiguous()
409
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
410
+
411
+ attn_output = self.dense(attn_output)
412
+
413
+ if not output_attentions:
414
+ attn_weights = None
415
+
416
+ return attn_output, attn_weights, past_key_value
417
+
418
+
419
+ class PhiFlashAttention2(PhiAttention):
420
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
421
+ def __init__(self, *args, **kwargs):
422
+ super().__init__(*args, **kwargs)
423
+
424
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
425
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
426
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
427
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ attention_mask: Optional[torch.LongTensor] = None,
433
+ position_ids: Optional[torch.LongTensor] = None,
434
+ past_key_value: Optional[Cache] = None,
435
+ output_attentions: bool = False,
436
+ use_cache: bool = False,
437
+ **kwargs,
438
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
439
+ # PhiFlashAttention2 attention does not support output_attentions
440
+
441
+ output_attentions = False
442
+
443
+ bsz, q_len, _ = hidden_states.size()
444
+
445
+ query_states = self.q_proj(hidden_states)
446
+ key_states = self.k_proj(hidden_states)
447
+ value_states = self.v_proj(hidden_states)
448
+
449
+ if self.qk_layernorm:
450
+ query_states = self.q_layernorm(query_states)
451
+ key_states = self.k_layernorm(key_states)
452
+
453
+ # Flash attention requires the input to have the shape
454
+ # batch_size x seq_length x head_dim x hidden_dim
455
+ # therefore we just need to keep the original shape
456
+ query_states = query_states.view(
457
+ bsz, q_len, self.num_heads, self.head_dim
458
+ ).transpose(1, 2)
459
+ key_states = key_states.view(
460
+ bsz, q_len, self.num_key_value_heads, self.head_dim
461
+ ).transpose(1, 2)
462
+ value_states = value_states.view(
463
+ bsz, q_len, self.num_key_value_heads, self.head_dim
464
+ ).transpose(1, 2)
465
+
466
+ kv_seq_len = key_states.shape[-2]
467
+ if past_key_value is not None:
468
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
469
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
470
+
471
+ # Partial rotary embedding
472
+ query_rot, query_pass = (
473
+ query_states[..., : self.rotary_emb.dim],
474
+ query_states[..., self.rotary_emb.dim :],
475
+ )
476
+ key_rot, key_pass = (
477
+ key_states[..., : self.rotary_emb.dim],
478
+ key_states[..., self.rotary_emb.dim :],
479
+ )
480
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
481
+ query_rot, key_rot = apply_rotary_pos_emb(
482
+ query_rot, key_rot, cos, sin, position_ids
483
+ )
484
+
485
+ # [batch_size, seq_length, num_heads, head_dim]
486
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
487
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
488
+
489
+ if past_key_value is not None:
490
+ cache_kwargs = {
491
+ "sin": sin,
492
+ "cos": cos,
493
+ "partial_rotation_size": self.rotary_emb.dim,
494
+ }
495
+ key_states, value_states = past_key_value.update(
496
+ key_states, value_states, self.layer_idx, cache_kwargs
497
+ )
498
+
499
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
500
+ # to be able to avoid many of these transpose/reshape/view.
501
+ query_states = query_states.transpose(1, 2)
502
+ key_states = key_states.transpose(1, 2)
503
+ value_states = value_states.transpose(1, 2)
504
+
505
+ attn_dropout = self.attention_dropout if self.training else 0.0
506
+
507
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
508
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
509
+ # cast them back in the correct dtype just to be sure everything works as expected.
510
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
511
+ # in fp32.
512
+
513
+ if query_states.dtype == torch.float32:
514
+ if torch.is_autocast_enabled():
515
+ target_dtype = torch.get_autocast_gpu_dtype()
516
+ # Handle the case where the model is quantized
517
+ elif hasattr(self.config, "_pre_quantization_dtype"):
518
+ target_dtype = self.config._pre_quantization_dtype
519
+ else:
520
+ target_dtype = self.q_proj.weight.dtype
521
+
522
+ logger.warning_once(
523
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
524
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
525
+ f" {target_dtype}."
526
+ )
527
+
528
+ query_states = query_states.to(target_dtype)
529
+ key_states = key_states.to(target_dtype)
530
+ value_states = value_states.to(target_dtype)
531
+
532
+ attn_output = self._flash_attention_forward(
533
+ query_states,
534
+ key_states,
535
+ value_states,
536
+ attention_mask,
537
+ q_len,
538
+ dropout=attn_dropout,
539
+ softmax_scale=None,
540
+ )
541
+
542
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
543
+ attn_output = self.dense(attn_output)
544
+
545
+ if not output_attentions:
546
+ attn_weights = None
547
+
548
+ return attn_output, attn_weights, past_key_value
549
+
550
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
551
+ def _flash_attention_forward(
552
+ self,
553
+ query_states,
554
+ key_states,
555
+ value_states,
556
+ attention_mask,
557
+ query_length,
558
+ dropout=0.0,
559
+ softmax_scale=None,
560
+ ):
561
+ if not self._flash_attn_uses_top_left_mask:
562
+ causal = self.is_causal
563
+ else:
564
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
565
+ causal = self.is_causal and query_length != 1
566
+
567
+ # Contains at least one padding token in the sequence
568
+ if attention_mask is not None:
569
+ batch_size = query_states.shape[0]
570
+ (
571
+ query_states,
572
+ key_states,
573
+ value_states,
574
+ indices_q,
575
+ cu_seq_lens,
576
+ max_seq_lens,
577
+ ) = self._upad_input(
578
+ query_states, key_states, value_states, attention_mask, query_length
579
+ )
580
+
581
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
582
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
583
+
584
+ attn_output_unpad = flash_attn_varlen_func(
585
+ query_states,
586
+ key_states,
587
+ value_states,
588
+ cu_seqlens_q=cu_seqlens_q,
589
+ cu_seqlens_k=cu_seqlens_k,
590
+ max_seqlen_q=max_seqlen_in_batch_q,
591
+ max_seqlen_k=max_seqlen_in_batch_k,
592
+ dropout_p=dropout,
593
+ softmax_scale=softmax_scale,
594
+ causal=causal,
595
+ )
596
+
597
+ attn_output = pad_input(
598
+ attn_output_unpad, indices_q, batch_size, query_length
599
+ )
600
+ else:
601
+ attn_output = flash_attn_func(
602
+ query_states,
603
+ key_states,
604
+ value_states,
605
+ dropout,
606
+ softmax_scale=softmax_scale,
607
+ causal=causal,
608
+ )
609
+
610
+ return attn_output
611
+
612
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
613
+ def _upad_input(
614
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
615
+ ):
616
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
617
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
618
+
619
+ key_layer = index_first_axis(
620
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
621
+ indices_k,
622
+ )
623
+ value_layer = index_first_axis(
624
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
625
+ indices_k,
626
+ )
627
+ if query_length == kv_seq_len:
628
+ query_layer = index_first_axis(
629
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
630
+ indices_k,
631
+ )
632
+ cu_seqlens_q = cu_seqlens_k
633
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
634
+ indices_q = indices_k
635
+ elif query_length == 1:
636
+ max_seqlen_in_batch_q = 1
637
+ cu_seqlens_q = torch.arange(
638
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
639
+ ) # There is a memcpy here, that is very bad.
640
+ indices_q = cu_seqlens_q[:-1]
641
+ query_layer = query_layer.squeeze(1)
642
+ else:
643
+ # The -q_len: slice assumes left padding.
644
+ attention_mask = attention_mask[:, -query_length:]
645
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
646
+ query_layer, attention_mask
647
+ )
648
+
649
+ return (
650
+ query_layer,
651
+ key_layer,
652
+ value_layer,
653
+ indices_q,
654
+ (cu_seqlens_q, cu_seqlens_k),
655
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
656
+ )
657
+
658
+
659
+ PHI_ATTENTION_CLASSES = {
660
+ "flash_attention_2": PhiFlashAttention2,
661
+ "eager": PhiAttention,
662
+ }
663
+
664
+
665
+ class PhiDecoderLayer(nn.Module):
666
+ def __init__(self, config: PhiConfig, layer_idx: int):
667
+ super().__init__()
668
+ if is_flash_attn_2_available():
669
+ config._attn_implementation = "flash_attention_2"
670
+ self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](
671
+ config, layer_idx=layer_idx
672
+ )
673
+ self.mlp = PhiMLP(config)
674
+ self.input_layernorm = nn.LayerNorm(
675
+ config.hidden_size, eps=config.layer_norm_eps
676
+ )
677
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
678
+
679
+ def forward(
680
+ self,
681
+ hidden_states: torch.Tensor,
682
+ attention_mask: Optional[torch.Tensor] = None,
683
+ position_ids: Optional[torch.LongTensor] = None,
684
+ output_attentions: Optional[bool] = False,
685
+ use_cache: Optional[bool] = False,
686
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
687
+ ) -> Tuple[
688
+ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
689
+ ]:
690
+ residual = hidden_states
691
+
692
+ hidden_states = self.input_layernorm(hidden_states)
693
+
694
+ # Self Attention
695
+ attn_outputs, self_attn_weights, present_key_value = self.self_attn(
696
+ hidden_states=hidden_states,
697
+ attention_mask=attention_mask,
698
+ position_ids=position_ids,
699
+ past_key_value=past_key_value,
700
+ output_attentions=output_attentions,
701
+ use_cache=use_cache,
702
+ )
703
+ attn_outputs = self.resid_dropout(attn_outputs)
704
+
705
+ feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states))
706
+ hidden_states = attn_outputs + feed_forward_hidden_states + residual
707
+ outputs = (hidden_states,)
708
+
709
+ if output_attentions:
710
+ outputs += (self_attn_weights,)
711
+
712
+ if use_cache:
713
+ outputs += (present_key_value,)
714
+
715
+ return outputs
716
+
717
+
718
+ class PhiPreTrainedModel(PreTrainedModel):
719
+ config_class = PhiConfig
720
+ base_model_prefix = "model"
721
+ supports_gradient_checkpointing = True
722
+ _no_split_modules = ["PhiDecoderLayer"]
723
+ _skip_keys_device_placement = "past_key_values"
724
+ _supports_flash_attn_2 = True
725
+ _supports_cache_class = True
726
+
727
+
728
+ class PhiModel(PhiPreTrainedModel):
729
+ """
730
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]
731
+
732
+ Args:
733
+ config: PhiConfig
734
+ """
735
+
736
+ def __init__(self, config: PhiConfig):
737
+ super().__init__(config)
738
+ self.padding_idx = config.pad_token_id
739
+ self.vocab_size = config.vocab_size
740
+
741
+ self.embed_tokens = nn.Embedding(
742
+ config.vocab_size, config.hidden_size, self.padding_idx
743
+ )
744
+ self.embed_dropout = nn.Dropout(config.embd_pdrop)
745
+ self.layers = nn.ModuleList(
746
+ [
747
+ PhiDecoderLayer(config, layer_idx)
748
+ for layer_idx in range(config.num_hidden_layers)
749
+ ]
750
+ )
751
+ self.final_layernorm = nn.LayerNorm(
752
+ config.hidden_size, eps=config.layer_norm_eps
753
+ )
754
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
755
+
756
+ self.gradient_checkpointing = False
757
+ # Initialize weights and apply final processing
758
+ self.post_init()
759
+
760
+ def get_input_embeddings(self):
761
+ return self.embed_tokens
762
+
763
+ def set_input_embeddings(self, value):
764
+ self.embed_tokens = value
765
+
766
+ def forward(
767
+ self,
768
+ input_ids: torch.LongTensor = None,
769
+ attention_mask: Optional[torch.Tensor] = None,
770
+ position_ids: Optional[torch.LongTensor] = None,
771
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
772
+ inputs_embeds: Optional[torch.FloatTensor] = None,
773
+ use_cache: Optional[bool] = None,
774
+ output_attentions: Optional[bool] = None,
775
+ output_hidden_states: Optional[bool] = None,
776
+ return_dict: Optional[bool] = None,
777
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
778
+ output_attentions = (
779
+ output_attentions
780
+ if output_attentions is not None
781
+ else self.config.output_attentions
782
+ )
783
+ output_hidden_states = (
784
+ output_hidden_states
785
+ if output_hidden_states is not None
786
+ else self.config.output_hidden_states
787
+ )
788
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
789
+
790
+ return_dict = (
791
+ return_dict if return_dict is not None else self.config.use_return_dict
792
+ )
793
+
794
+ # retrieve input_ids and inputs_embeds
795
+ if input_ids is not None and inputs_embeds is not None:
796
+ raise ValueError(
797
+ "You cannot specify both input_ids and inputs_embeds at the same time"
798
+ )
799
+ elif input_ids is not None:
800
+ batch_size, seq_length = input_ids.shape[:2]
801
+ elif inputs_embeds is not None:
802
+ batch_size, seq_length = inputs_embeds.shape[:2]
803
+ else:
804
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
805
+
806
+ past_key_values_length = 0
807
+
808
+ if self.gradient_checkpointing and self.training:
809
+ if use_cache:
810
+ logger.warning_once(
811
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
812
+ )
813
+ use_cache = False
814
+
815
+ if use_cache:
816
+ use_legacy_cache = not isinstance(past_key_values, Cache)
817
+ if use_legacy_cache:
818
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
819
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
820
+
821
+ if position_ids is None:
822
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
823
+ position_ids = torch.arange(
824
+ past_key_values_length,
825
+ seq_length + past_key_values_length,
826
+ dtype=torch.long,
827
+ device=device,
828
+ )
829
+ position_ids = position_ids.unsqueeze(0)
830
+
831
+ if inputs_embeds is None:
832
+ inputs_embeds = self.embed_tokens(input_ids)
833
+
834
+ inputs_embeds = self.embed_dropout(inputs_embeds)
835
+
836
+ # Attention mask.
837
+ if self._use_flash_attention_2:
838
+ # 2d mask is passed through the layers
839
+ attention_mask = (
840
+ attention_mask
841
+ if (attention_mask is not None and 0 in attention_mask)
842
+ else None
843
+ )
844
+ else:
845
+ # 4d mask is passed through the layers
846
+ attention_mask = _prepare_4d_causal_attention_mask(
847
+ attention_mask,
848
+ (batch_size, seq_length),
849
+ inputs_embeds,
850
+ past_key_values_length,
851
+ )
852
+
853
+ hidden_states = inputs_embeds
854
+
855
+ # decoder layers
856
+ all_hidden_states = () if output_hidden_states else None
857
+ all_self_attns = () if output_attentions else None
858
+ next_decoder_cache = None
859
+
860
+ for decoder_layer in self.layers:
861
+ if output_hidden_states:
862
+ all_hidden_states += (hidden_states,)
863
+
864
+ if self.gradient_checkpointing and self.training:
865
+ layer_outputs = self._gradient_checkpointing_func(
866
+ decoder_layer.__call__,
867
+ hidden_states,
868
+ attention_mask,
869
+ position_ids,
870
+ past_key_values,
871
+ output_attentions,
872
+ )
873
+ else:
874
+ layer_outputs = decoder_layer(
875
+ hidden_states,
876
+ attention_mask=attention_mask,
877
+ position_ids=position_ids,
878
+ past_key_value=past_key_values,
879
+ output_attentions=output_attentions,
880
+ use_cache=use_cache,
881
+ )
882
+
883
+ hidden_states = layer_outputs[0]
884
+
885
+ if use_cache:
886
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
887
+
888
+ if output_attentions:
889
+ all_self_attns += (layer_outputs[1],)
890
+
891
+ hidden_states = self.final_layernorm(hidden_states)
892
+
893
+ # add hidden states from the last decoder layer
894
+ if output_hidden_states:
895
+ all_hidden_states += (hidden_states,)
896
+
897
+ next_cache = None
898
+ if use_cache:
899
+ next_cache = (
900
+ next_decoder_cache.to_legacy_cache()
901
+ if use_legacy_cache
902
+ else next_decoder_cache
903
+ )
904
+ if not return_dict:
905
+ return tuple(
906
+ v
907
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
908
+ if v is not None
909
+ )
910
+ return BaseModelOutputWithPast(
911
+ last_hidden_state=hidden_states,
912
+ past_key_values=next_cache,
913
+ hidden_states=all_hidden_states,
914
+ attentions=all_self_attns,
915
+ )
916
+
917
+
918
+ class PhiForCausalLM(PhiPreTrainedModel):
919
+ _tied_weights_keys = ["lm_head.weight"]
920
+
921
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
922
+ def __init__(self, config):
923
+ super().__init__(config)
924
+ self.model = PhiModel(config)
925
+ self.vocab_size = config.vocab_size
926
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)
927
+
928
+ # Initialize weights and apply final processing
929
+ self.post_init()
930
+
931
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
932
+ def get_input_embeddings(self):
933
+ return self.model.embed_tokens
934
+
935
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
936
+ def set_input_embeddings(self, value):
937
+ self.model.embed_tokens = value
938
+
939
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
940
+ def get_output_embeddings(self):
941
+ return self.lm_head
942
+
943
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
944
+ def set_output_embeddings(self, new_embeddings):
945
+ self.lm_head = new_embeddings
946
+
947
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
948
+ def set_decoder(self, decoder):
949
+ self.model = decoder
950
+
951
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
952
+ def get_decoder(self):
953
+ return self.model
954
+
955
+ def forward(
956
+ self,
957
+ input_ids: torch.LongTensor = None,
958
+ attention_mask: Optional[torch.Tensor] = None,
959
+ position_ids: Optional[torch.LongTensor] = None,
960
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
961
+ inputs_embeds: Optional[torch.FloatTensor] = None,
962
+ labels: Optional[torch.LongTensor] = None,
963
+ use_cache: Optional[bool] = None,
964
+ output_attentions: Optional[bool] = None,
965
+ output_hidden_states: Optional[bool] = None,
966
+ return_dict: Optional[bool] = None,
967
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
968
+ output_attentions = (
969
+ output_attentions
970
+ if output_attentions is not None
971
+ else self.config.output_attentions
972
+ )
973
+ output_hidden_states = (
974
+ output_hidden_states
975
+ if output_hidden_states is not None
976
+ else self.config.output_hidden_states
977
+ )
978
+ return_dict = (
979
+ return_dict if return_dict is not None else self.config.use_return_dict
980
+ )
981
+
982
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
983
+ outputs = self.model(
984
+ input_ids=input_ids,
985
+ attention_mask=attention_mask,
986
+ position_ids=position_ids,
987
+ past_key_values=past_key_values,
988
+ inputs_embeds=inputs_embeds,
989
+ use_cache=use_cache,
990
+ output_attentions=output_attentions,
991
+ output_hidden_states=output_hidden_states,
992
+ return_dict=return_dict,
993
+ )
994
+
995
+ hidden_states = outputs[0]
996
+ logits = self.lm_head(hidden_states)
997
+ logits = logits.float()
998
+
999
+ loss = None
1000
+ if labels is not None:
1001
+ # Shift so that tokens < n predict n
1002
+ shift_logits = logits[..., :-1, :].contiguous()
1003
+ shift_labels = labels[..., 1:].contiguous()
1004
+ # Flatten the tokens
1005
+ loss_fct = CrossEntropyLoss()
1006
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1007
+ shift_labels = shift_labels.view(-1)
1008
+ # Enable model parallelism
1009
+ shift_labels = shift_labels.to(shift_logits.device)
1010
+ loss = loss_fct(shift_logits, shift_labels)
1011
+
1012
+ if not return_dict:
1013
+ output = (logits,) + outputs[1:]
1014
+ return (loss,) + output if loss is not None else output
1015
+
1016
+ return CausalLMOutputWithPast(
1017
+ loss=loss,
1018
+ logits=logits,
1019
+ past_key_values=outputs.past_key_values,
1020
+ hidden_states=outputs.hidden_states,
1021
+ attentions=outputs.attentions,
1022
+ )
1023
+
1024
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
1025
+ def prepare_inputs_for_generation(
1026
+ self,
1027
+ input_ids,
1028
+ past_key_values=None,
1029
+ attention_mask=None,
1030
+ inputs_embeds=None,
1031
+ **kwargs,
1032
+ ):
1033
+ if past_key_values is not None:
1034
+ if isinstance(past_key_values, Cache):
1035
+ cache_length = past_key_values.get_seq_length()
1036
+ past_length = past_key_values.seen_tokens
1037
+ max_cache_length = past_key_values.get_max_length()
1038
+ else:
1039
+ cache_length = past_length = past_key_values[0][0].shape[2]
1040
+ max_cache_length = None
1041
+
1042
+ # Keep only the unprocessed tokens:
1043
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1044
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1045
+ # input)
1046
+ if (
1047
+ attention_mask is not None
1048
+ and attention_mask.shape[1] > input_ids.shape[1]
1049
+ ):
1050
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1051
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1052
+ # input_ids based on the past_length.
1053
+ elif past_length < input_ids.shape[1]:
1054
+ input_ids = input_ids[:, past_length:]
1055
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1056
+
1057
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1058
+ if (
1059
+ max_cache_length is not None
1060
+ and attention_mask is not None
1061
+ and cache_length + input_ids.shape[1] > max_cache_length
1062
+ ):
1063
+ attention_mask = attention_mask[:, -max_cache_length:]
1064
+
1065
+ position_ids = kwargs.get("position_ids", None)
1066
+ if attention_mask is not None and position_ids is None:
1067
+ # create position_ids on the fly for batch generation
1068
+ position_ids = attention_mask.long().cumsum(-1) - 1
1069
+ position_ids.masked_fill_(attention_mask == 0, 1)
1070
+ if past_key_values:
1071
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1072
+
1073
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1074
+ if inputs_embeds is not None and past_key_values is None:
1075
+ model_inputs = {"inputs_embeds": inputs_embeds}
1076
+ else:
1077
+ model_inputs = {"input_ids": input_ids}
1078
+
1079
+ model_inputs.update(
1080
+ {
1081
+ "position_ids": position_ids,
1082
+ "past_key_values": past_key_values,
1083
+ "use_cache": kwargs.get("use_cache"),
1084
+ "attention_mask": attention_mask,
1085
+ }
1086
+ )
1087
+ return model_inputs
1088
+
1089
+ @staticmethod
1090
+ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
1091
+ def _reorder_cache(past_key_values, beam_idx):
1092
+ reordered_past = ()
1093
+ for layer_past in past_key_values:
1094
+ reordered_past += (
1095
+ tuple(
1096
+ past_state.index_select(0, beam_idx.to(past_state.device))
1097
+ for past_state in layer_past
1098
+ ),
1099
+ )
1100
+ return reordered_past
1101
+
1102
+
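The causal-LM head above is trained with the usual one-step shift: the logits at position t are scored against the token at position t + 1. A minimal, self-contained sketch of that shift (the tensor sizes are illustrative, not taken from this model):

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab = 2, 6, 51200           # illustrative sizes
logits = torch.randn(batch, seq_len, vocab)    # what lm_head would produce
labels = torch.randint(0, vocab, (batch, seq_len))

# Drop the last position's logits and the first position's labels,
# so logits[:, t] is compared against labels[:, t + 1].
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = labels[..., 1:].contiguous().view(-1)

loss = CrossEntropyLoss()(shift_logits, shift_labels)
print(loss.item())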
1103
+ class PhiForSequenceClassification(PhiPreTrainedModel):
1104
+ def __init__(self, config):
1105
+ super().__init__(config)
1106
+ self.num_labels = config.num_labels
1107
+ self.model = PhiModel(config)
1108
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1109
+
1110
+ # Initialize weights and apply final processing
1111
+ self.post_init()
1112
+
1113
+ def get_input_embeddings(self):
1114
+ return self.model.embed_tokens
1115
+
1116
+ def set_input_embeddings(self, value):
1117
+ self.model.embed_tokens = value
1118
+
1119
+ def forward(
1120
+ self,
1121
+ input_ids: torch.LongTensor = None,
1122
+ attention_mask: Optional[torch.Tensor] = None,
1123
+ position_ids: Optional[torch.LongTensor] = None,
1124
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1125
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1126
+ labels: Optional[torch.LongTensor] = None,
1127
+ use_cache: Optional[bool] = None,
1128
+ output_attentions: Optional[bool] = None,
1129
+ output_hidden_states: Optional[bool] = None,
1130
+ return_dict: Optional[bool] = None,
1131
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
1132
+ r"""
1133
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1134
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1135
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss); if
1136
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
+ """
1138
+ return_dict = (
1139
+ return_dict if return_dict is not None else self.config.use_return_dict
1140
+ )
1141
+
1142
+ model_outputs = self.model(
1143
+ input_ids,
1144
+ attention_mask=attention_mask,
1145
+ position_ids=position_ids,
1146
+ past_key_values=past_key_values,
1147
+ inputs_embeds=inputs_embeds,
1148
+ use_cache=use_cache,
1149
+ output_attentions=output_attentions,
1150
+ output_hidden_states=output_hidden_states,
1151
+ return_dict=return_dict,
1152
+ )
1153
+ hidden_states = model_outputs[0]
1154
+ logits = self.score(hidden_states)
1155
+
1156
+ if input_ids is not None:
1157
+ batch_size = input_ids.shape[0]
1158
+ else:
1159
+ batch_size = inputs_embeds.shape[0]
1160
+
1161
+ if self.config.pad_token_id is None and batch_size != 1:
1162
+ raise ValueError(
1163
+ "Cannot handle batch sizes > 1 if no padding token is defined."
1164
+ )
1165
+ if self.config.pad_token_id is None:
1166
+ sequence_lengths = -1
1167
+ else:
1168
+ if input_ids is not None:
1169
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1170
+ sequence_lengths = (
1171
+ torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1172
+ )
1173
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1174
+ sequence_lengths = sequence_lengths.to(logits.device)
1175
+ else:
1176
+ sequence_lengths = -1
1177
+
1178
+ pooled_logits = logits[
1179
+ torch.arange(batch_size, device=logits.device), sequence_lengths
1180
+ ]
1181
+
1182
+ loss = None
1183
+ if labels is not None:
1184
+ labels = labels.to(logits.device)
1185
+ if self.config.problem_type is None:
1186
+ if self.num_labels == 1:
1187
+ self.config.problem_type = "regression"
1188
+ elif self.num_labels > 1 and (
1189
+ labels.dtype == torch.long or labels.dtype == torch.int
1190
+ ):
1191
+ self.config.problem_type = "single_label_classification"
1192
+ else:
1193
+ self.config.problem_type = "multi_label_classification"
1194
+
1195
+ if self.config.problem_type == "regression":
1196
+ loss_fct = MSELoss()
1197
+ if self.num_labels == 1:
1198
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1199
+ else:
1200
+ loss = loss_fct(pooled_logits, labels)
1201
+ elif self.config.problem_type == "single_label_classification":
1202
+ loss_fct = CrossEntropyLoss()
1203
+ loss = loss_fct(
1204
+ pooled_logits.view(-1, self.num_labels), labels.view(-1)
1205
+ )
1206
+ elif self.config.problem_type == "multi_label_classification":
1207
+ loss_fct = BCEWithLogitsLoss()
1208
+ loss = loss_fct(pooled_logits, labels)
1209
+ if not return_dict:
1210
+ output = (pooled_logits,) + model_outputs[1:]
1211
+ return ((loss,) + output) if loss is not None else output
1212
+
1213
+ return SequenceClassifierOutputWithPast(
1214
+ loss=loss,
1215
+ logits=pooled_logits,
1216
+ past_key_values=model_outputs.past_key_values,
1217
+ hidden_states=model_outputs.hidden_states,
1218
+ attentions=model_outputs.attentions,
1219
+ )
1220
+
1221
+
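The pooling above selects one position per sequence: the index just before the first pad token, wrapped with a modulo so that sequences without padding resolve to the last position. A small illustration of that index arithmetic (the pad id of 0 is arbitrary for the example):

import torch

pad_token_id = 0                               # arbitrary for this example
input_ids = torch.tensor([
    [5, 7, 9, pad_token_id, pad_token_id],     # real length 3
    [4, 4, 4, 4, 4],                           # no padding at all
])

seq_len = input_ids.shape[-1]
lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
lengths = lengths % seq_len                    # argmax is 0 when no pad -> -1 % 5 == 4
print(lengths)                                 # tensor([2, 4])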
1222
+ class PhiForTokenClassification(PhiPreTrainedModel):
1223
+ def __init__(self, config: PhiConfig):
1224
+ super().__init__(config)
1225
+ self.num_labels = config.num_labels
1226
+
1227
+ self.model = PhiModel(config)
1228
+ if (
1229
+ hasattr(config, "classifier_dropout")
1230
+ and config.classifier_dropout is not None
1231
+ ):
1232
+ classifier_dropout = config.classifier_dropout
1233
+ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
1234
+ classifier_dropout = config.hidden_dropout
1235
+ else:
1236
+ classifier_dropout = 0.1
1237
+ self.dropout = nn.Dropout(classifier_dropout)
1238
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1239
+
1240
+ # Initialize weights and apply final processing
1241
+ self.post_init()
1242
+
1243
+ def forward(
1244
+ self,
1245
+ input_ids: Optional[torch.LongTensor] = None,
1246
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
1247
+ attention_mask: Optional[torch.Tensor] = None,
1248
+ inputs_embeds: Optional[torch.Tensor] = None,
1249
+ labels: Optional[torch.Tensor] = None,
1250
+ use_cache: Optional[bool] = None,
1251
+ output_attentions: Optional[bool] = None,
1252
+ output_hidden_states: Optional[bool] = None,
1253
+ return_dict: Optional[bool] = None,
1254
+ **deprecated_arguments,
1255
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1256
+ r"""
1257
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1258
+ Labels for computing the token classification loss. Indices should be in `[0, ...,
1259
+ config.num_labels - 1]`; positions set to `-100` are ignored by the cross-entropy
1260
+ loss.
1261
+ """
1262
+ return_dict = (
1263
+ return_dict if return_dict is not None else self.config.use_return_dict
1264
+ )
1265
+
1266
+ model_outputs = self.model(
1267
+ input_ids,
1268
+ past_key_values=past_key_values,
1269
+ attention_mask=attention_mask,
1270
+ inputs_embeds=inputs_embeds,
1271
+ use_cache=use_cache,
1272
+ output_attentions=output_attentions,
1273
+ output_hidden_states=output_hidden_states,
1274
+ return_dict=return_dict,
1275
+ )
1276
+
1277
+ hidden_states = model_outputs[0]
1278
+ hidden_states = self.dropout(hidden_states)
1279
+ logits = self.classifier(hidden_states)
1280
+
1281
+ loss = None
1282
+ if labels is not None:
1283
+ # move labels to correct device to enable model parallelism
1284
+ labels = labels.to(logits.device)
1285
+ batch_size, seq_length = labels.shape
1286
+ loss_fct = CrossEntropyLoss()
1287
+ loss = loss_fct(
1288
+ logits.view(batch_size * seq_length, self.num_labels),
1289
+ labels.view(batch_size * seq_length),
1290
+ )
1291
+
1292
+ if not return_dict:
1293
+ output = (logits,) + model_outputs[2:]
1294
+ return ((loss,) + output) if loss is not None else output
1295
+
1296
+ return TokenClassifierOutput(
1297
+ loss=loss,
1298
+ logits=logits,
1299
+ hidden_states=model_outputs.hidden_states,
1300
+ attentions=model_outputs.attentions,
1301
+ )
1302
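The token-classification loss above simply flattens the per-token logits and labels before the cross-entropy. A minimal sketch with made-up sizes:

import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, num_labels = 2, 4, 3           # made-up sizes
logits = torch.randn(batch, seq_len, num_labels)
labels = torch.randint(0, num_labels, (batch, seq_len))

# One prediction per token: (batch * seq_len, num_labels) vs (batch * seq_len,)
loss = CrossEntropyLoss()(
    logits.view(batch * seq_len, num_labels),
    labels.view(batch * seq_len),
)
print(loss.item())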
 
1303
 
1304
  @dataclass
 
1311
  image_features: Optional[torch.FloatTensor] = None
1312
 
1313
 
1314
+ class SiglipVisionEncoder(nn.Module):
1315
+ def __init__(self, config: LlavaConfig):
1316
+ super().__init__()
1317
+ self.vision_tower = SiglipVisionModel(config.vision_config)
1318
+
1319
+ self.coord_embed = nn.Sequential(
1320
+ nn.Linear(2, config.vision_embed_dim),
1321
+ nn.GELU(),
1322
+ nn.Linear(config.vision_embed_dim, config.vision_embed_dim),
1323
+ )
1324
+
1325
+ self.num_tokens = 728
1326
+
1327
+ def feature_select(self, image_forward_outs, coord_feature, num_tokens=None):
1328
+ image_features = image_forward_outs
1329
+ image_features = image_features[:, 1:]
1330
+ if num_tokens is None:
1331
+ num_tokens = self.num_tokens
1332
+ split_size = int(num_tokens / image_features.shape[0])
1333
+ used_tokens = 0
1334
+ output_list = []
1335
+ for i in range(image_features.shape[0]):
1336
+ if i == image_features.shape[0] - 1:
1337
+ size = num_tokens - used_tokens
1338
+ else:
1339
+ size = split_size
1340
+ used_tokens += size
1341
+ chunk_output = image_features[i, -size:, :]
1342
+ chunk_output = chunk_output + coord_feature[i]
1343
+ output_list.append(chunk_output)
1344
+ image_features = torch.cat(output_list)
1345
+ return image_features
1346
+
1347
+ def process_image_chunks(self, image_tensor, coord_tensor, num_tokens=None):
1348
+ if image_tensor.shape[0] > 50:
1349
+ image_forward_out = []
1350
+ for i in range(0, image_tensor.shape[0], 50):
1351
+ part_forward_out = self.vision_tower(image_tensor[i:i+50], output_hidden_states=True).hidden_states[-1]
1352
+ image_forward_out.append(part_forward_out)
1353
+ image_forward_out = torch.cat(image_forward_out, dim=0)
1354
+ else:
1355
+ image_forward_out = self.vision_tower(image_tensor, output_hidden_states=True).hidden_states[-1]
1356
+ coord_feature = self.coord_embed(coord_tensor)
1357
+ if len(coord_feature.shape) == 1:
1358
+ coord_feature = coord_feature.unsqueeze(0)
1359
+ image_feature = self.feature_select(image_forward_out, coord_feature, num_tokens).to(
1360
+ image_tensor.dtype
1361
+ )
1362
+ return image_feature
1363
+
1364
+ def forward(self, images: List[torch.Tensor], coords: List[torch.Tensor], num_tokens=None):
1365
+ image_features = []
1366
+ for i, image in enumerate(images):
1367
+ image_feature = self.process_image_chunks(image, coords[i], num_tokens)
1368
+ image_features.append(image_feature)
1369
+ image_features = torch.stack(image_features)
1370
+ return image_features
1371
+
1372
+
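feature_select above spreads a fixed token budget across the image crops: every crop keeps the last `num_tokens // num_crops` patch embeddings, the final crop absorbs the rounding remainder, and each chunk is then offset by its crop's coordinate embedding. The split arithmetic, reproduced as a standalone helper rather than the module itself:

# Reproduces the split used by feature_select: each crop keeps the last
# `split_size` patch tokens; the final crop absorbs the rounding remainder.
def chunk_sizes(num_crops: int, num_tokens: int) -> list:
    split_size = num_tokens // num_crops
    sizes = [split_size] * (num_crops - 1)
    sizes.append(num_tokens - split_size * (num_crops - 1))
    return sizes

print(chunk_sizes(5, 728))   # [145, 145, 145, 145, 148] -> 728 tokens total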
1373
  class LlavaMultiModalProjector(nn.Module):
1374
  def __init__(self, config: LlavaConfig):
1375
  super().__init__()
1376
 
1377
  self.linear_1 = nn.Linear(
1378
  config.vision_embed_dim,
1379
+ config.text_config.hidden_size,
1380
  bias=True,
1381
  )
1382
  self.act = nn.GELU()
1383
  self.linear_2 = nn.Linear(
1384
+ config.text_config.hidden_size,
1385
+ config.text_config.hidden_size,
1386
  bias=True,
1387
  )
 
1388
 
1389
  def forward(self, image_features):
1390
  hidden_states = self.linear_1(image_features)
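The remainder of this forward is collapsed in the diff; presumably it applies `self.act` and `self.linear_2`, i.e. the usual two-layer LLaVA-style projector. A standalone sketch under that assumption (the 1152 -> 2560 sizes mirror the config but are hard-coded here for illustration):

import torch
from torch import nn

class TwoLayerProjector(nn.Module):
    def __init__(self, vision_dim=1152, text_dim=2560):
        super().__init__()
        self.linear_1 = nn.Linear(vision_dim, text_dim, bias=True)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(text_dim, text_dim, bias=True)

    def forward(self, image_features):
        # project vision features into the text model's hidden size
        return self.linear_2(self.act(self.linear_1(image_features)))

print(TwoLayerProjector()(torch.randn(1, 728, 1152)).shape)  # torch.Size([1, 728, 2560])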
 
1416
  return self.language_model._supports_sdpa
1417
 
1418
 
1419
+ class LlavaForCausalLM(LlavaPreTrainedModel):
1420
  def __init__(self, config: LlavaConfig):
1421
  super().__init__(config)
1422
+ self.vision_model = SiglipVisionEncoder(config)
 
1423
 
1424
  self.multi_modal_projector = LlavaMultiModalProjector(config)
1425
  self.vocab_size = config.vocab_size
 
1605
 
1606
  logits = outputs[0]
1607
 
 
1608
  if not return_dict:
1609
  output = (logits,) + outputs[1:]
1610
  return output
 
1626
  image_features=None,
1627
  **kwargs,
1628
  ):
1629
+ res = self.language_model.prepare_inputs_for_generation(
1630
+ input_ids, past_key_values, attention_mask, **kwargs
1631
+ )
1632
  input_ids = res["input_ids"]
1633
  past_key_values = res["past_key_values"]
1634
  attention_mask = res["attention_mask"]
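Since config.json maps `AutoModel` and `AutoModelForCausalLM` to `modeling_llava.LlavaForCausalLM`, the model is meant to be loaded with remote code enabled. A hedged loading sketch; the repository id is a placeholder:

import torch
from transformers import AutoModel

repo_id = "visheratin/MC-LLaVA-3b"   # placeholder id for illustration
model = AutoModel.from_pretrained(
    repo_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
model = model.to("cuda")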
preprocessor_config.json ADDED
@@ -0,0 +1,25 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoProcessor": "processing_llava.LlavaProcessor",
4
+ "AutoImageProcessor": "processing_llava.MultiCropImageProcessor"
5
+ },
6
+ "do_normalize": true,
7
+ "do_rescale": true,
8
+ "do_resize": true,
9
+ "image_mean": [
10
+ 0.5,
11
+ 0.5,
12
+ 0.5
13
+ ],
14
+ "image_std": [
15
+ 0.5,
16
+ 0.5,
17
+ 0.5
18
+ ],
19
+ "resample": 3,
20
+ "rescale_factor": 0.00392156862745098,
21
+ "size": {
22
+ "height": 384,
23
+ "width": 384
24
+ }
25
+ }
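These values describe the standard SigLIP preprocessing: resize to 384x384, rescale by 1/255 (the 0.00392156862745098 above), then normalize each channel with mean 0.5 and std 0.5, which maps pixel values into [-1, 1]. A quick check of that arithmetic:

# rescale then normalize, as the config above specifies
rescale_factor = 1 / 255            # == 0.00392156862745098
mean = std = 0.5

for pixel in (0, 128, 255):
    value = (pixel * rescale_factor - mean) / std
    print(pixel, round(value, 4))   # 0 -> -1.0, 128 -> 0.0039, 255 -> 1.0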
processing_llava.py CHANGED
@@ -1,24 +1,10 @@
1
- # coding=utf-8
2
- # Copyright 2023 The HuggingFace Inc. team.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- """
16
- Processor class for Llava.
17
- """
18
-
19
-
20
  from typing import List, Optional, Union
21
 
 
 
 
 
22
  from transformers.feature_extraction_utils import BatchFeature
23
  from transformers.image_utils import ImageInput
24
  from transformers.tokenization_utils_base import (
@@ -28,52 +14,73 @@ from transformers.tokenization_utils_base import (
28
  TruncationStrategy,
29
  )
30
  from transformers.utils import TensorType
31
- import torch
32
- from open_clip.transform import PreprocessCfg, image_transform_v2
33
- from modeling_llava import LlavaForConditionalGeneration
34
- from PIL import Image
35
- import math
36
 
37
 
38
- class OpenCLIPImageProcessor:
39
- def __init__(self, config, crop_size=384, max_tokens=100):
40
- cfg = PreprocessCfg(**config)
41
- transform = image_transform_v2(cfg=cfg, is_train=False)
42
- self.transform = transform
43
- self.crop_size = crop_size
44
- self.max_tokens = max_tokens
45
 
46
- def __call__(self, image: Image.Image):
47
- output = self.transform_func(image)
48
- return {
49
- "pixel_values": output,
 
 
 
 
50
  }
 
 
 
 
 
 
 
51
 
52
- def transform_func(self, image: Image.Image):
 
 
 
 
53
  outputs = []
54
- outputs.append(self.transform(image))
 
 
55
  width, height = image.size
56
  crop_size = self.crop_size
57
- if width <= crop_size and height <= crop_size:
58
- outputs = torch.stack(outputs, dim=0)
59
- return outputs
 
 
 
 
 
 
60
  total_tokens = math.inf
61
- while total_tokens > self.max_tokens:
62
- total_tokens = math.floor(
63
- (2 * width - crop_size)
64
- / crop_size
65
- * (2 * height - crop_size)
66
- / crop_size
67
  )
68
- if total_tokens > self.max_tokens:
69
  crop_size += 10
70
- stride = crop_size // 2
71
- x_steps = int(round((2 * width - crop_size) / crop_size))
 
72
  if x_steps < 1:
73
  x_steps = 1
74
- y_steps = int(round((2 * height - crop_size) / crop_size))
75
  if y_steps < 1:
76
  y_steps = 1
 
 
 
 
77
  x_coords = []
78
  y_coords = []
79
  for i in range(x_steps):
@@ -85,6 +92,7 @@ class OpenCLIPImageProcessor:
85
  if y_coords[-1][1] != height:
86
  y_coords[-1][1] = height
87
  image_parts = []
 
88
  for i in range(len(x_coords)):
89
  for j in range(len(y_coords)):
90
  image_parts.append(
@@ -92,20 +100,38 @@ class OpenCLIPImageProcessor:
92
  (x_coords[i][0], y_coords[j][0], x_coords[i][1], y_coords[j][1])
93
  )
94
  )
 
 
 
 
 
 
 
 
95
  for image_part in image_parts:
96
- outputs.append(self.transform(image_part))
97
- outputs = torch.stack(outputs, dim=0)
98
- return outputs
 
 
 
99
 
100
- @property
101
- def model_input_names(self):
102
- return ["pixel_values"]
103
 
 
 
 
 
104
 
105
- class LlavaProcessor:
106
- def __init__(self, image_processor: OpenCLIPImageProcessor, tokenizer):
107
  self.image_processor = image_processor
108
  self.tokenizer = tokenizer
 
 
 
 
 
 
 
109
 
110
  def __call__(
111
  self,
@@ -113,20 +139,24 @@ class LlavaProcessor:
113
  TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
114
  ] = None,
115
  images: ImageInput = None,
116
- model: LlavaForConditionalGeneration = None,
 
 
117
  padding: Union[bool, str, PaddingStrategy] = False,
118
  truncation: Union[bool, str, TruncationStrategy] = None,
119
  max_length=None,
120
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
121
  ) -> BatchFeature:
122
  if images is not None:
123
- pixel_values = self.image_processor(images)[
124
- "pixel_values"
 
 
125
  ]
126
- pixel_values = pixel_values.to(model.device).to(model.dtype)
127
- image_outputs = model.vision_model(pixel_values)
 
128
  image_features = model.multi_modal_projector(image_outputs)
129
- image_features = image_features.unsqueeze(0)
130
  else:
131
  image_features = None
132
  text_inputs = self.tokenizer(
@@ -136,7 +166,8 @@ class LlavaProcessor:
136
  truncation=truncation,
137
  max_length=max_length,
138
  )
139
-
 
140
  return BatchFeature(data={**text_inputs, "image_features": image_features})
141
 
142
  def batch_decode(self, *args, **kwargs):
 
1
+ import math
2
  from typing import List, Optional, Union
3
 
4
+ import torch
5
+ from modeling_llava import LlavaForCausalLM
6
+ from PIL import Image
7
+ from transformers import ImageProcessingMixin, ProcessorMixin, SiglipImageProcessor, AutoTokenizer, AutoImageProcessor
8
  from transformers.feature_extraction_utils import BatchFeature
9
  from transformers.image_utils import ImageInput
10
  from transformers.tokenization_utils_base import (
 
14
  TruncationStrategy,
15
  )
16
  from transformers.utils import TensorType
 
 
 
 
 
17
 
18
 
19
+ class MultiCropImageProcessor(ImageProcessingMixin):
20
+ def __init__(self, model_name, max_crops=0, **kwargs):
21
+ self.processor = SiglipImageProcessor.from_pretrained(model_name)
22
+ self.crop_size = 384
23
+ self.max_crops = max_crops
24
+ self.stride_ratio = 2
 
25
 
26
+ def __call__(
27
+ self,
28
+ images: List[Image.Image],
29
+ max_crops: int = -1,
30
+ ):
31
+ res = {
32
+ "pixel_values": [],
33
+ "coords": [],
34
  }
35
+ if max_crops < 0:
36
+ max_crops = self.max_crops
37
+ for image in images:
38
+ outputs, output_coords = self.process_image(image, max_crops)
39
+ res["pixel_values"].append(outputs)
40
+ res["coords"].append(output_coords)
41
+ return res
42
 
43
+ def process_image(
44
+ self,
45
+ image: Image.Image,
46
+ max_crops: int
47
+ ):
48
  outputs = []
49
+ output_coords = []
50
+ outputs.append(self.processor(image, return_tensors="pt").pixel_values)
51
+ output_coords.append(torch.tensor([0.5, 0.5]))
52
  width, height = image.size
53
  crop_size = self.crop_size
54
+ stride = crop_size // self.stride_ratio
55
+ if (
56
+ max_crops == 0
57
+ or (width <= (crop_size + stride)
58
+ and height <= (crop_size + stride))
59
+ ):
60
+ outputs = torch.cat(outputs, dim=0)
61
+ output_coords = torch.cat(output_coords, dim=0)
62
+ return outputs, output_coords
63
  total_tokens = math.inf
64
+ while total_tokens > max_crops:
65
+ total_tokens = (
66
+ math.floor((width - crop_size) / stride) + 1
67
+ ) * (
68
+ math.floor((height - crop_size) / stride) + 1
 
69
  )
70
+ if total_tokens > max_crops:
71
  crop_size += 10
72
+ stride = crop_size // self.stride_ratio
73
+ stride = crop_size // self.stride_ratio
74
+ x_steps = int(math.floor((width - crop_size) / stride) + 1)
75
  if x_steps < 1:
76
  x_steps = 1
77
+ y_steps = int(math.floor((height - crop_size) / stride) + 1)
78
  if y_steps < 1:
79
  y_steps = 1
80
+ if x_steps == 1 and y_steps == 1:
81
+ outputs = torch.cat(outputs, dim=0)
82
+ output_coords = torch.cat(output_coords, dim=0)
83
+ return outputs, output_coords
84
  x_coords = []
85
  y_coords = []
86
  for i in range(x_steps):
 
92
  if y_coords[-1][1] != height:
93
  y_coords[-1][1] = height
94
  image_parts = []
95
+ part_coords = []
96
  for i in range(len(x_coords)):
97
  for j in range(len(y_coords)):
98
  image_parts.append(
 
100
  (x_coords[i][0], y_coords[j][0], x_coords[i][1], y_coords[j][1])
101
  )
102
  )
103
+ part_coords.append(
104
+ torch.tensor(
105
+ [
106
+ (x_coords[i][0] + x_coords[i][1]) / 2 / width,
107
+ (y_coords[j][0] + y_coords[j][1]) / 2 / height,
108
+ ]
109
+ )
110
+ )
111
  for image_part in image_parts:
112
+ outputs.append(self.processor(image_part, return_tensors="pt").pixel_values)
113
+ for part_coord in part_coords:
114
+ output_coords.append(part_coord)
115
+ outputs = torch.cat(outputs, dim=0)
116
+ output_coords = torch.stack(output_coords, dim=0)
117
+ return outputs, output_coords
118
 
 
 
 
119
 
120
+ class LlavaProcessor(ProcessorMixin):
121
+ attributes = ["image_processor", "tokenizer"]
122
+ image_processor_class = MultiCropImageProcessor
123
+ tokenizer_class = "CodeGenTokenizer"
124
 
125
+ def __init__(self, image_processor: MultiCropImageProcessor, tokenizer):
 
126
  self.image_processor = image_processor
127
  self.tokenizer = tokenizer
128
+ self.search_model = None
129
+
130
+ @classmethod
131
+ def from_pretrained(cls, path, trust_remote_code=True, **kwargs):
132
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code)
133
+ image_processor = MultiCropImageProcessor(path, trust_remote_code=trust_remote_code)
134
+ return LlavaProcessor(image_processor, tokenizer)
135
 
136
  def __call__(
137
  self,
 
139
  TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
140
  ] = None,
141
  images: ImageInput = None,
142
+ model: LlavaForCausalLM = None,
143
+ max_crops: int = 0,
144
+ num_tokens=None,
145
  padding: Union[bool, str, PaddingStrategy] = False,
146
  truncation: Union[bool, str, TruncationStrategy] = None,
147
  max_length=None,
148
  return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
149
  ) -> BatchFeature:
150
  if images is not None:
151
+ processor_outputs = self.image_processor(images, max_crops)
152
+ pixel_values = processor_outputs["pixel_values"]
153
+ pixel_values = [
154
+ value.to(model.device).to(model.dtype) for value in pixel_values
155
  ]
156
+ coords = processor_outputs["coords"]
157
+ coords = [value.to(model.device).to(model.dtype) for value in coords]
158
+ image_outputs = model.vision_model(pixel_values, coords, num_tokens)
159
  image_features = model.multi_modal_projector(image_outputs)
 
160
  else:
161
  image_features = None
162
  text_inputs = self.tokenizer(
 
166
  truncation=truncation,
167
  max_length=max_length,
168
  )
169
+ text_inputs['input_ids'] = text_inputs['input_ids'].to(model.device)
170
+ text_inputs['attention_mask'] = text_inputs['attention_mask'].to(model.device)
171
  return BatchFeature(data={**text_inputs, "image_features": image_features})
172
 
173
  def batch_decode(self, *args, **kwargs):
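Put together, inference runs the processor (which needs the model because it calls the vision tower and projector itself) and then model.generate with the precomputed image features. A hedged end-to-end sketch; the repo id, prompt layout, and generation settings are illustrative, not taken from this diff:

import torch
from PIL import Image
from transformers import AutoModel, AutoProcessor

repo_id = "visheratin/MC-LLaVA-3b"   # placeholder id for illustration
model = AutoModel.from_pretrained(repo_id, torch_dtype=torch.float16, trust_remote_code=True).to("cuda")
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

prompt = "<|im_start|>user\n<image>\nDescribe this image.<|im_end|>\n<|im_start|>assistant\n"
image = Image.open("example.jpg")

with torch.inference_mode():
    # the processor runs the vision encoder and projector, so it takes the model
    inputs = processor(prompt, [image], model, max_crops=50, num_tokens=728)
    output_ids = model.generate(**inputs, max_new_tokens=100, use_cache=True)

print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))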
tokenizer_config.json CHANGED
@@ -347,10 +347,9 @@
347
  }
348
  },
349
  "bos_token": "<|endoftext|>",
350
- "chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
351
  "clean_up_tokenization_spaces": true,
352
  "eos_token": "<|im_end|>",
353
- "model_max_length": 1200,
354
  "pad_token": "<pad>",
355
  "tokenizer_class": "CodeGenTokenizer",
356
  "unk_token": "<|endoftext|>"
 
347
  }
348
  },
349
  "bos_token": "<|endoftext|>",
 
350
  "clean_up_tokenization_spaces": true,
351
  "eos_token": "<|im_end|>",
352
+ "model_max_length": 2048,
353
  "pad_token": "<pad>",
354
  "tokenizer_class": "CodeGenTokenizer",
355
  "unk_token": "<|endoftext|>"