BAAI
/

ryanzhangfan commited on
Commit
94953a3
1 Parent(s): 748b205

Upload 30 files

Browse files
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ }
28
+ }
model_index.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "EmuVisualGenerationPipeline",
3
+ "_diffusers_version": "0.21.2",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "multimodal_encoder": [
9
+ "transformers_modules.modeling_emu",
10
+ "EmuForCausalLM"
11
+ ],
12
+ "safety_checker": [
13
+ "stable_diffusion",
14
+ "StableDiffusionSafetyChecker"
15
+ ],
16
+ "scheduler": [
17
+ "diffusers",
18
+ "EulerDiscreteScheduler"
19
+ ],
20
+ "tokenizer": [
21
+ "transformers",
22
+ "LlamaTokenizerFast"
23
+ ],
24
+ "unet": [
25
+ "diffusers",
26
+ "UNet2DConditionModel"
27
+ ],
28
+ "vae": [
29
+ "diffusers",
30
+ "AutoencoderKL"
31
+ ]
32
+ }
multimodal_encoder/config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/share/project/quansun/release_hf/Emu2-VisualGeneration/multimodal_encoder/",
3
+ "architectures": [
4
+ "EmuForCausalLM"
5
+ ],
6
+ "attention_bias": false,
7
+ "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_emu.EmuConfig",
10
+ "AutoModelForCausalLM": "modeling_emu.EmuForCausalLM"
11
+ },
12
+ "bos_token_id": 1,
13
+ "d_model": 1792,
14
+ "eos_token_id": 2,
15
+ "hidden_act": "silu",
16
+ "hidden_size": 6656,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 17920,
19
+ "max_position_embeddings": 2048,
20
+ "model_version": "base",
21
+ "num_attention_heads": 52,
22
+ "num_hidden_layers": 60,
23
+ "num_key_value_heads": 52,
24
+ "pad_token_id": 32000,
25
+ "pretraining_tp": 1,
26
+ "rms_norm_eps": 1e-06,
27
+ "rope_scaling": null,
28
+ "rope_theta": 10000.0,
29
+ "tie_word_embeddings": false,
30
+ "torch_dtype": "bfloat16",
31
+ "transformers_version": "4.31.0",
32
+ "use_cache": true,
33
+ "vision_config": {
34
+ "drop_path_rate": 0,
35
+ "eva_model_name": "eva-clip-E-14-plus",
36
+ "head_width": 112,
37
+ "image_size": 448,
38
+ "intermediate_size": 15360,
39
+ "layer_norm_eps": 1e-06,
40
+ "layers": 64,
41
+ "mlp_ratio": 8.571428571428571,
42
+ "n_query": 64,
43
+ "patch_size": 14,
44
+ "postnorm": true,
45
+ "qkv_bias": true,
46
+ "v_query": 64,
47
+ "width": 1792,
48
+ "xattn": true
49
+ },
50
+ "vocab_size": 32272
51
+ }
multimodal_encoder/configuration_emu.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+ from transformers import PretrainedConfig
3
+
4
+
5
+ class EmuConfig(PretrainedConfig):
6
+ _auto_class = "AutoConfig"
7
+
8
+ def __init__(
9
+ self,
10
+ vocab_size=32000,
11
+ hidden_size=4096,
12
+ intermediate_size=11008,
13
+ num_hidden_layers=32,
14
+ num_attention_heads=32,
15
+ hidden_act='silu',
16
+ max_position_embeddings=2048,
17
+ initializer_range=0.02,
18
+ rms_norm_eps=1e-06,
19
+ model_version: Literal["base", "chat"] = "base",
20
+ pad_token_id=0,
21
+ bos_token_id=1,
22
+ eos_token_id=2,
23
+ tie_word_embeddings=False,
24
+ use_cache=True,
25
+ pretraining_tp=1,
26
+ rope_theta=10000.0,
27
+ rope_scaling=None,
28
+ attention_bias=False,
29
+ attention_dropout=0.0,
30
+ **kwargs,
31
+ ):
32
+ self.hidden_size = hidden_size
33
+ self.intermediate_size = intermediate_size
34
+ self.num_attention_heads = num_attention_heads
35
+ self.max_position_embeddings = max_position_embeddings
36
+ self.rms_norm_eps = rms_norm_eps
37
+ self.initializer_range = initializer_range
38
+ self.vocab_size = vocab_size
39
+ self.num_hidden_layers = num_hidden_layers
40
+ self.hidden_act = hidden_act
41
+ self.model_version = model_version
42
+ self.use_cache = use_cache
43
+ self.pretraining_tp = pretraining_tp
44
+ self.use_cache = use_cache
45
+ self.rope_theta = rope_theta
46
+ self.rope_scaling = rope_scaling
47
+ self._rope_scaling_validation()
48
+ self.attention_bias = attention_bias
49
+ self.attention_dropout = attention_dropout
50
+ super().__init__(
51
+ pad_token_id=pad_token_id,
52
+ bos_token_id=bos_token_id,
53
+ eos_token_id=eos_token_id,
54
+ tie_word_embeddings=tie_word_embeddings,
55
+ **kwargs,
56
+ )
57
+
58
+ def _rope_scaling_validation(self):
59
+ """
60
+ Validate the `rope_scaling` configuration.
61
+ """
62
+ if self.rope_scaling is None:
63
+ return
64
+
65
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
66
+ raise ValueError(
67
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
68
+ f"got {self.rope_scaling}"
69
+ )
70
+ rope_scaling_type = self.rope_scaling.get("type", None)
71
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
72
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
73
+ raise ValueError(
74
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
75
+ )
76
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
77
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
multimodal_encoder/constants.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ EVA_IMAGE_SIZE = 448
2
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
3
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
4
+
5
+ DEFAULT_IMAGE_FILE_SUFFIX = ['jpg', '0.png', 'png', 'jpeg', 'webp']
6
+ DEFAULT_TEXT_FILE_SUFFIX = ['txt', '0.txt']
7
+
8
+ IGNORE_INDEX = -100
9
+
10
+ # special tokens
11
+ # START
12
+ DEFAULT_PAD_TOKEN = "[PAD]"
13
+ DEFAULT_BOS_TOKEN = '<s>'
14
+ DEFAULT_EOS_TOKEN = '</s>'
15
+ DEFAULT_UNK_TOKEN = "<unk>"
16
+
17
+ DEFAULT_IMG_TOKEN = "[IMG]"
18
+ DEFAULT_IMG_END_TOKEN = "[/IMG]"
19
+ DEFAULT_IMAGE_TOKEN = "<image>"
20
+ DEFAULT_gIMG_TOKEN = "[gIMG]"
21
+ DEFAULT_gIMG_END_TOKEN = "[/gIMG]"
22
+ DEFAULT_EOC_TOKEN = "[EOC]"
23
+ DEFAULT_VIDEO_TOKEN = "[VIDEO]"
24
+
25
+ GRD_SYMBOL = "<grounding>"
26
+ BOP_SYMBOL = "<phrase>"
27
+ EOP_SYMBOL = "</phrase>"
28
+ BOO_SYMBOL = "<object>"
29
+ EOO_SYMBOL = "</object>"
30
+ DOM_SYMBOL = "</delimiter_of_multi_objects/>"
31
+
32
+ REC_SYMBOL = "<REC>"
33
+
34
+ USER_TOKEN = "[USER]"
35
+ ASSISTANT_TOKEN = "[ASSISTANT]"
36
+ # END
37
+
38
+ # special token id
39
+ # START
40
+ IMAGE = 32003
41
+ BOI = 32001
42
+ VIDEO = 32004
43
+ # END
44
+
45
+ DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]"
46
+ DEFAULT_VID_PLACEHOLDER = "[<VID_PLH>]"
47
+ FAKE_VIDEO_END_TOKEN = "[/VIDEO]"
multimodal_encoder/model.bf16-00001-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:849f23e3d375518a179cb7887cb8861f088e185e7619e518a38ec2a069417f87
3
+ size 9961629600
multimodal_encoder/model.bf16-00002-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae62cc224559ee79ccc91687e3457310f3797f7517df944d02af637cad666cf4
3
+ size 9958082896
multimodal_encoder/model.bf16-00003-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3690630dfd3ad092a527fbd5a00bc3881c6e1ff4cedf8c46001eec8a47c1e9f3
3
+ size 9896714920
multimodal_encoder/model.bf16-00004-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9b92e277b4a31bf1daaea769b8702f32ea0cf61657f1d0f64305fe0b8ed266a
3
+ size 9869451296
multimodal_encoder/model.bf16-00005-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:adba114c2f977df27e344297798cce0fae6537891339e3aa030764d892004aa1
3
+ size 9869451296
multimodal_encoder/model.bf16-00006-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed94b7b7fdfe014355af7b0eb99be16bf5b0e0d384cd07c358bbc078fb1d2c22
3
+ size 9958082992
multimodal_encoder/model.bf16-00007-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f89cfc60475e3454e315fa73fb4afc263e89d87a2c93377411156c5462346590
3
+ size 9896714920
multimodal_encoder/model.bf16-00008-of-00008.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:332b756156697afb8614b55baee954df24db366d80603e6dc83e6d3b1d5e0e4d
3
+ size 4403309264
multimodal_encoder/model.safetensors.index.bf16.json ADDED
The diff for this file is too large to render. See raw diff
 
multimodal_encoder/modeling_emu.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ from typing import List, Optional
3
+ from argparse import Namespace
4
+ import torch
5
+ from torch import nn
6
+ import torch.nn.functional as F
7
+ from transformers import PreTrainedModel, PreTrainedTokenizer
8
+
9
+ from .configuration_emu import EmuConfig
10
+ from .constants import *
11
+ from .modeling_llama import LlamaForCausalLM
12
+ from .visual import EVAVisionTransformer
13
+
14
+
15
+ class EmuPreTrainedModel(PreTrainedModel):
16
+ config_class = EmuConfig
17
+ base_model_prefix = "model"
18
+ supports_gradient_checkpointing = False
19
+ _no_split_modules = ["LlamaDecoderLayer", "Block"]
20
+ _skip_keys_device_placement = "past_key_values"
21
+
22
+ def _init_weights(self, module):
23
+ std = self.config.initializer_range
24
+ if isinstance(module, nn.Linear):
25
+ module.weight.data.normal_(mean=0.0, std=std)
26
+ if module.bias is not None:
27
+ module.bias.data.zero_()
28
+ elif isinstance(module, nn.Embedding):
29
+ module.weight.data.normal_(mean=0.0, std=std)
30
+ if module.padding_idx is not None:
31
+ module.weight.data[module.padding_idx].zero_()
32
+
33
+ class EmuForClsAndRegression(EmuPreTrainedModel):
34
+
35
+ def __init__(self, config):
36
+ super(EmuForClsAndRegression, self).__init__(config)
37
+
38
+ self.lm = LlamaForCausalLM(config=config)
39
+
40
+ self.lm.model.embed_tokens.padding_idx = config.pad_token_id
41
+
42
+ def get_num_layers(self):
43
+ return len(self.lm.model.layers)
44
+
45
+ class EmuModel(EmuPreTrainedModel):
46
+
47
+ def __init__(self, config):
48
+ super().__init__(config)
49
+
50
+ vision_config = Namespace(**config.vision_config)
51
+
52
+ self.visual = EVAVisionTransformer(
53
+ img_size=vision_config.image_size,
54
+ patch_size=vision_config.patch_size,
55
+ embed_dim=vision_config.width,
56
+ depth=vision_config.layers,
57
+ num_heads=vision_config.width // vision_config.head_width,
58
+ mlp_ratio=vision_config.mlp_ratio,
59
+ qkv_bias=vision_config.qkv_bias,
60
+ drop_path_rate=vision_config.drop_path_rate,
61
+ norm_layer=partial(nn.LayerNorm, eps=vision_config.layer_norm_eps),
62
+ xattn=vision_config.xattn,
63
+ postnorm=vision_config.postnorm,
64
+ )
65
+
66
+ self.decoder = EmuForClsAndRegression(config)
67
+
68
+ self.gradient_checkpointing = False
69
+
70
+ self.n_query = vision_config.n_query
71
+ self.v_query = vision_config.v_query
72
+
73
+ @property
74
+ def device(self):
75
+ return next(iter(self.parameters())).device
76
+
77
+ @property
78
+ def dtype(self):
79
+ return next(iter(self.parameters())).dtype
80
+
81
+ @torch.no_grad()
82
+ def encode_image(self, image: torch.Tensor, *, n_query=None):
83
+ n_query = n_query if n_query is not None else self.n_query
84
+
85
+ image_embeds = self.visual(image)
86
+ image_embeds = image_embeds[:, 1:, :]
87
+ b, n, c = image_embeds.shape
88
+ sqrt_n = int(n**0.5)
89
+ image_embeds = image_embeds.permute(0, 2, 1).view(b, c, sqrt_n, sqrt_n)
90
+
91
+ stride = int(sqrt_n // (n_query ** 0.5))
92
+ image_embeds = F.avg_pool2d(image_embeds, kernel_size=(stride, stride), stride=stride)
93
+ image_embeds = image_embeds.view(b, c, -1).permute(0, 2, 1).contiguous()
94
+ return image_embeds
95
+
96
+
97
+ class EmuForCausalLM(EmuPreTrainedModel):
98
+ _auto_class = "AutoModelForCausalLM"
99
+
100
+ def __init__(self, config):
101
+ super().__init__(config)
102
+
103
+ self.config = config
104
+ self.model = EmuModel(config)
105
+ # LM to EVA
106
+ self.project_down = nn.Linear(config.hidden_size, config.d_model, bias=False)
107
+ # EVA to LM
108
+ self.project_up = nn.Linear(config.d_model, config.hidden_size, bias=False)
109
+
110
+ self.n_query = self.model.n_query
111
+ self.image_placeholder = DEFAULT_IMG_TOKEN + DEFAULT_IMAGE_TOKEN * self.n_query + DEFAULT_IMG_END_TOKEN
112
+
113
+ def device(self, module=None):
114
+ if module is None:
115
+ return next(self.parameters()).device
116
+ return next(module.parameters()).device
117
+
118
+ def dtype(self, module):
119
+ if module is None:
120
+ return next(self.parameters()).dtype
121
+ return next(module.parameters()).dtype
122
+
123
+ @torch.no_grad()
124
+ def generate_image(
125
+ self,
126
+ text: List[str],
127
+ tokenizer: PreTrainedTokenizer,
128
+ image: Optional[torch.Tensor] = None,
129
+ placeholder: str = DEFAULT_IMG_PLACEHOLDER,
130
+ ):
131
+ IMAGE, BOI = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_TOKEN, DEFAULT_IMG_TOKEN])
132
+ if image is not None:
133
+ prompt_image_embeds = self.model.encode_image(image)
134
+ _, _, c = prompt_image_embeds.shape
135
+ prompt_image_embeds = prompt_image_embeds.view(-1, c)
136
+ prompt_image_embeds = self.project_up(prompt_image_embeds)
137
+
138
+ text = [t.replace(placeholder, self.image_placeholder) for t in text]
139
+
140
+ target_image_embeds = None
141
+ for num_img_token in range(self.n_query):
142
+ if num_img_token == 0:
143
+ text = [f"{t}{DEFAULT_IMG_TOKEN}" for t in text]
144
+ else:
145
+ text = [f"{t}{DEFAULT_IMAGE_TOKEN}" for t in text]
146
+
147
+ inputs = tokenizer(text, padding="longest", return_tensors="pt")
148
+ device = self.device(self.model.decoder.lm.model.embed_tokens)
149
+ attention_mask = inputs.attention_mask.to(device)
150
+ input_ids = inputs.input_ids.to(device) # B x N
151
+
152
+ text_embeds = self.model.decoder.lm.model.embed_tokens(input_ids)
153
+
154
+ image_idx = (input_ids == IMAGE)
155
+ cumsum_idx = torch.flip(torch.cumsum(torch.flip(image_idx, dims=[1]), dim=1), dims=[1])
156
+ if image is not None:
157
+ prompt_idx = torch.logical_and(image_idx, cumsum_idx > num_img_token)
158
+ text_embeds[prompt_idx] = prompt_image_embeds.to(text_embeds.device)
159
+
160
+ if target_image_embeds is not None:
161
+ target_idx = torch.logical_and(image_idx, torch.logical_and(cumsum_idx > 0, cumsum_idx <= num_img_token))
162
+ text_embeds[target_idx] = self.project_up(target_image_embeds).to(text_embeds.device)
163
+
164
+ outputs = self.model.decoder.lm.model(
165
+ inputs_embeds=text_embeds,
166
+ attention_mask=attention_mask,
167
+ output_hidden_states=True,
168
+ return_dict=True,
169
+ )
170
+
171
+ image_idx = (input_ids == IMAGE) + (input_ids == BOI)
172
+ cumsum_idx = torch.flip(torch.cumsum(torch.flip(image_idx, dims=[1]), dim=1), dims=[1])
173
+ target_idx = torch.logical_and(image_idx, torch.logical_and(cumsum_idx > 0, cumsum_idx <= num_img_token+1))
174
+
175
+ hidden_states = outputs.hidden_states[-1]
176
+ target_image_embeds = hidden_states[target_idx.to(hidden_states.device)]
177
+ target_image_embeds = target_image_embeds.view(-1, target_image_embeds.shape[-1])
178
+ target_image_embeds = self.project_down(target_image_embeds)
179
+
180
+ _, C = target_image_embeds.shape
181
+ B = hidden_states.shape[0]
182
+ target_image_embeds = target_image_embeds.view(B, -1, C)
183
+
184
+ return target_image_embeds
185
+
multimodal_encoder/modeling_llama.py ADDED
@@ -0,0 +1,1011 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers import PreTrainedModel
31
+ from transformers import LlamaConfig
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
34
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
35
+
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CONFIG_FOR_DOC = "LlamaConfig"
41
+
42
+
43
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
44
+ def _make_causal_mask(
45
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
46
+ ):
47
+ """
48
+ Make causal mask used for bi-directional self-attention.
49
+ """
50
+ bsz, tgt_len = input_ids_shape
51
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
52
+ mask_cond = torch.arange(mask.size(-1), device=device)
53
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
54
+ mask = mask.to(dtype)
55
+
56
+ if past_key_values_length > 0:
57
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
58
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
59
+
60
+
61
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
62
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
63
+ """
64
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
65
+ """
66
+ bsz, src_len = mask.size()
67
+ tgt_len = tgt_len if tgt_len is not None else src_len
68
+
69
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
70
+
71
+ inverted_mask = 1.0 - expanded_mask
72
+
73
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
74
+
75
+
76
+ class LlamaRMSNorm(nn.Module):
77
+ def __init__(self, hidden_size, eps=1e-6):
78
+ """
79
+ LlamaRMSNorm is equivalent to T5LayerNorm
80
+ """
81
+ super().__init__()
82
+ self.weight = nn.Parameter(torch.ones(hidden_size))
83
+ self.variance_epsilon = eps
84
+
85
+ def forward(self, hidden_states):
86
+ input_dtype = hidden_states.dtype
87
+ hidden_states = hidden_states.to(torch.float32)
88
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
89
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
90
+ return self.weight * hidden_states.to(input_dtype)
91
+
92
+
93
+ class LlamaRotaryEmbedding(torch.nn.Module):
94
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
+ super().__init__()
96
+
97
+ self.dim = dim
98
+ self.max_position_embeddings = max_position_embeddings
99
+ self.base = base
100
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
101
+ self.register_buffer("inv_freq", inv_freq)
102
+
103
+ # Build here to make `torch.jit.trace` work.
104
+ self._set_cos_sin_cache(
105
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
106
+ )
107
+
108
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
109
+ self.max_seq_len_cached = seq_len
110
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
111
+
112
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
113
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
114
+ emb = torch.cat((freqs, freqs), dim=-1)
115
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
116
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
117
+
118
+ def forward(self, x, seq_len=None):
119
+ # x: [bs, num_attention_heads, seq_len, head_size]
120
+ if seq_len > self.max_seq_len_cached:
121
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
122
+
123
+ return (
124
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
125
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
126
+ )
127
+
128
+
129
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
130
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
131
+
132
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
133
+ self.scaling_factor = scaling_factor
134
+ super().__init__(dim, max_position_embeddings, base, device)
135
+
136
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
137
+ self.max_seq_len_cached = seq_len
138
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
139
+ t = t / self.scaling_factor
140
+
141
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
142
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
143
+ emb = torch.cat((freqs, freqs), dim=-1)
144
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
145
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
146
+
147
+
148
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
149
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
150
+
151
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
152
+ self.scaling_factor = scaling_factor
153
+ super().__init__(dim, max_position_embeddings, base, device)
154
+
155
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
156
+ self.max_seq_len_cached = seq_len
157
+
158
+ if seq_len > self.max_position_embeddings:
159
+ base = self.base * (
160
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
161
+ ) ** (self.dim / (self.dim - 2))
162
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
163
+ self.register_buffer("inv_freq", inv_freq)
164
+
165
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
166
+
167
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
168
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
169
+ emb = torch.cat((freqs, freqs), dim=-1)
170
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
171
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
172
+
173
+
174
+ def rotate_half(x):
175
+ """Rotates half the hidden dims of the input."""
176
+ x1 = x[..., : x.shape[-1] // 2]
177
+ x2 = x[..., x.shape[-1] // 2 :]
178
+ return torch.cat((-x2, x1), dim=-1)
179
+
180
+
181
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
182
+
183
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
184
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
185
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
186
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
187
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
188
+ q_embed = (q * cos) + (rotate_half(q) * sin)
189
+ k_embed = (k * cos) + (rotate_half(k) * sin)
190
+ return q_embed, k_embed
191
+
192
+
193
+ class LlamaMLP(nn.Module):
194
+ def __init__(self, config):
195
+ super().__init__()
196
+ self.pretraining_tp = config.pretraining_tp
197
+ self.hidden_size = config.hidden_size
198
+ self.intermediate_size = config.intermediate_size
199
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
200
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
201
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
202
+ self.act_fn = ACT2FN[config.hidden_act]
203
+
204
+ def forward(self, x):
205
+ if self.pretraining_tp > 1:
206
+ slice = self.intermediate_size // self.pretraining_tp
207
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
208
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
209
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
210
+
211
+ gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
212
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
213
+
214
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
215
+ down_proj = [F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.pretraining_tp)]
216
+ down_proj = sum(down_proj)
217
+ else:
218
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
219
+
220
+ return down_proj
221
+
222
+
223
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
224
+ """
225
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
226
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
227
+ """
228
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
229
+ if n_rep == 1:
230
+ return hidden_states
231
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
232
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
233
+
234
+
235
+ class LlamaAttention(nn.Module):
236
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
237
+
238
+ def __init__(self, config: LlamaConfig):
239
+ super().__init__()
240
+ self.config = config
241
+ self.hidden_size = config.hidden_size
242
+ self.num_heads = config.num_attention_heads
243
+ self.head_dim = self.hidden_size // self.num_heads
244
+ self.num_key_value_heads = config.num_key_value_heads
245
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
246
+ self.pretraining_tp = config.pretraining_tp
247
+ self.max_position_embeddings = config.max_position_embeddings
248
+
249
+ if (self.head_dim * self.num_heads) != self.hidden_size:
250
+ raise ValueError(
251
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
252
+ f" and `num_heads`: {self.num_heads})."
253
+ )
254
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
255
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
256
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
257
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
258
+ self._init_rope()
259
+
260
+ def _init_rope(self):
261
+ if self.config.rope_scaling is None:
262
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
263
+ else:
264
+ scaling_type = self.config.rope_scaling["type"]
265
+ scaling_factor = self.config.rope_scaling["factor"]
266
+ if scaling_type == "linear":
267
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
268
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
269
+ )
270
+ elif scaling_type == "dynamic":
271
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
272
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
273
+ )
274
+ else:
275
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
276
+
277
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
278
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
279
+
280
+ def forward(
281
+ self,
282
+ hidden_states: torch.Tensor,
283
+ attention_mask: Optional[torch.Tensor] = None,
284
+ position_ids: Optional[torch.LongTensor] = None,
285
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
286
+ output_attentions: bool = False,
287
+ use_cache: bool = False,
288
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
289
+ bsz, q_len, _ = hidden_states.size()
290
+
291
+ if self.pretraining_tp > 1:
292
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
293
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
294
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
295
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
296
+
297
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
298
+ query_states = torch.cat(query_states, dim=-1)
299
+
300
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
301
+ key_states = torch.cat(key_states, dim=-1)
302
+
303
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
304
+ value_states = torch.cat(value_states, dim=-1)
305
+
306
+ else:
307
+ query_states = self.q_proj(hidden_states)
308
+ key_states = self.k_proj(hidden_states)
309
+ value_states = self.v_proj(hidden_states)
310
+
311
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
312
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
313
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
314
+
315
+ kv_seq_len = key_states.shape[-2]
316
+ if past_key_value is not None:
317
+ kv_seq_len += past_key_value[0].shape[-2]
318
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
319
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
320
+
321
+ if past_key_value is not None:
322
+ # reuse k, v, self_attention
323
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
324
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
325
+
326
+ past_key_value = (key_states, value_states) if use_cache else None
327
+
328
+ # repeat k/v heads if n_kv_heads < n_heads
329
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
330
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
331
+
332
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
333
+
334
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
335
+ raise ValueError(
336
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
337
+ f" {attn_weights.size()}"
338
+ )
339
+
340
+ if attention_mask is not None:
341
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
342
+ raise ValueError(
343
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
344
+ )
345
+ attn_weights = attn_weights + attention_mask
346
+
347
+ # upcast attention to fp32
348
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
349
+ attn_output = torch.matmul(attn_weights, value_states)
350
+
351
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
352
+ raise ValueError(
353
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
354
+ f" {attn_output.size()}"
355
+ )
356
+
357
+ attn_output = attn_output.transpose(1, 2).contiguous()
358
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
359
+
360
+ if self.pretraining_tp > 1:
361
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
362
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
363
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
364
+ else:
365
+ attn_output = self.o_proj(attn_output)
366
+
367
+ if not output_attentions:
368
+ attn_weights = None
369
+
370
+ return attn_output, attn_weights, past_key_value
371
+
372
+
373
+ class LlamaDecoderLayer(nn.Module):
374
+ def __init__(self, config: LlamaConfig):
375
+ super().__init__()
376
+ self.hidden_size = config.hidden_size
377
+ self.self_attn = LlamaAttention(config=config)
378
+ self.mlp = LlamaMLP(config)
379
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
380
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
381
+
382
+ def forward(
383
+ self,
384
+ hidden_states: torch.Tensor,
385
+ attention_mask: Optional[torch.Tensor] = None,
386
+ position_ids: Optional[torch.LongTensor] = None,
387
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
388
+ output_attentions: Optional[bool] = False,
389
+ use_cache: Optional[bool] = False,
390
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
391
+ """
392
+ Args:
393
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
394
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
395
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
396
+ output_attentions (`bool`, *optional*):
397
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
398
+ returned tensors for more detail.
399
+ use_cache (`bool`, *optional*):
400
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
401
+ (see `past_key_values`).
402
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
403
+ """
404
+
405
+ residual = hidden_states
406
+
407
+ hidden_states = self.input_layernorm(hidden_states)
408
+
409
+ # Self Attention
410
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
411
+ hidden_states=hidden_states,
412
+ attention_mask=attention_mask,
413
+ position_ids=position_ids,
414
+ past_key_value=past_key_value,
415
+ output_attentions=output_attentions,
416
+ use_cache=use_cache,
417
+ )
418
+ hidden_states = residual + hidden_states
419
+
420
+ # Fully Connected
421
+ residual = hidden_states
422
+ hidden_states = self.post_attention_layernorm(hidden_states)
423
+ hidden_states = self.mlp(hidden_states)
424
+ hidden_states = residual + hidden_states
425
+
426
+ outputs = (hidden_states,)
427
+
428
+ if output_attentions:
429
+ outputs += (self_attn_weights,)
430
+
431
+ if use_cache:
432
+ outputs += (present_key_value,)
433
+
434
+ return outputs
435
+
436
+
437
+ LLAMA_START_DOCSTRING = r"""
438
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
439
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
440
+ etc.)
441
+
442
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
443
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
444
+ and behavior.
445
+
446
+ Parameters:
447
+ config ([`LlamaConfig`]):
448
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
449
+ load the weights associated with the model, only the configuration. Check out the
450
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
451
+ """
452
+
453
+
454
+ @add_start_docstrings(
455
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
456
+ LLAMA_START_DOCSTRING,
457
+ )
458
+ class LlamaPreTrainedModel(PreTrainedModel):
459
+ config_class = LlamaConfig
460
+ base_model_prefix = "model"
461
+ supports_gradient_checkpointing = True
462
+ _no_split_modules = ["LlamaDecoderLayer"]
463
+ _skip_keys_device_placement = "past_key_values"
464
+
465
+ def _init_weights(self, module):
466
+ std = self.config.initializer_range
467
+ if isinstance(module, nn.Linear):
468
+ module.weight.data.normal_(mean=0.0, std=std)
469
+ if module.bias is not None:
470
+ module.bias.data.zero_()
471
+ elif isinstance(module, nn.Embedding):
472
+ module.weight.data.normal_(mean=0.0, std=std)
473
+ if module.padding_idx is not None:
474
+ module.weight.data[module.padding_idx].zero_()
475
+
476
+ def _set_gradient_checkpointing(self, module, value=False):
477
+ if isinstance(module, LlamaModel):
478
+ module.gradient_checkpointing = value
479
+
480
+
481
+ LLAMA_INPUTS_DOCSTRING = r"""
482
+ Args:
483
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
484
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
485
+ it.
486
+
487
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
488
+ [`PreTrainedTokenizer.__call__`] for details.
489
+
490
+ [What are input IDs?](../glossary#input-ids)
491
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
492
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
493
+
494
+ - 1 for tokens that are **not masked**,
495
+ - 0 for tokens that are **masked**.
496
+
497
+ [What are attention masks?](../glossary#attention-mask)
498
+
499
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
500
+ [`PreTrainedTokenizer.__call__`] for details.
501
+
502
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
503
+ `past_key_values`).
504
+
505
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
506
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
507
+ information on the default strategy.
508
+
509
+ - 1 indicates the head is **not masked**,
510
+ - 0 indicates the head is **masked**.
511
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
512
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
513
+ config.n_positions - 1]`.
514
+
515
+ [What are position IDs?](../glossary#position-ids)
516
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
517
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
518
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
519
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
520
+
521
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
522
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
523
+
524
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
525
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
526
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
527
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
528
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
529
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
530
+ model's internal embedding lookup matrix.
531
+ use_cache (`bool`, *optional*):
532
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
533
+ `past_key_values`).
534
+ output_attentions (`bool`, *optional*):
535
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
536
+ tensors for more detail.
537
+ output_hidden_states (`bool`, *optional*):
538
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
539
+ more detail.
540
+ return_dict (`bool`, *optional*):
541
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
542
+ """
543
+
544
+
545
+ @add_start_docstrings(
546
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
547
+ LLAMA_START_DOCSTRING,
548
+ )
549
+ class LlamaModel(LlamaPreTrainedModel):
550
+ """
551
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
552
+
553
+ Args:
554
+ config: LlamaConfig
555
+ """
556
+
557
+ def __init__(self, config: LlamaConfig):
558
+ super().__init__(config)
559
+ self.padding_idx = config.pad_token_id
560
+ self.vocab_size = config.vocab_size
561
+
562
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
563
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
564
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
565
+
566
+ self.gradient_checkpointing = False
567
+ # Initialize weights and apply final processing
568
+ self.post_init()
569
+
570
+ def get_input_embeddings(self):
571
+ return self.embed_tokens
572
+
573
+ def set_input_embeddings(self, value):
574
+ self.embed_tokens = value
575
+
576
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
577
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
578
+ # create causal mask
579
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
580
+ combined_attention_mask = None
581
+ if input_shape[-1] > 1:
582
+ combined_attention_mask = _make_causal_mask(
583
+ input_shape,
584
+ inputs_embeds.dtype,
585
+ device=inputs_embeds.device,
586
+ past_key_values_length=past_key_values_length,
587
+ )
588
+
589
+ if attention_mask is not None:
590
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
591
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
592
+ inputs_embeds.device
593
+ )
594
+ combined_attention_mask = (
595
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
596
+ )
597
+
598
+ return combined_attention_mask
599
+
600
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
601
+ def forward(
602
+ self,
603
+ input_ids: torch.LongTensor = None,
604
+ attention_mask: Optional[torch.Tensor] = None,
605
+ position_ids: Optional[torch.LongTensor] = None,
606
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
607
+ inputs_embeds: Optional[torch.FloatTensor] = None,
608
+ use_cache: Optional[bool] = None,
609
+ output_attentions: Optional[bool] = None,
610
+ output_hidden_states: Optional[bool] = None,
611
+ return_dict: Optional[bool] = None,
612
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
613
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
614
+ output_hidden_states = (
615
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
616
+ )
617
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
618
+
619
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
620
+
621
+ # retrieve input_ids and inputs_embeds
622
+ if input_ids is not None and inputs_embeds is not None:
623
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
624
+ elif input_ids is not None:
625
+ batch_size, seq_length = input_ids.shape
626
+ elif inputs_embeds is not None:
627
+ batch_size, seq_length, _ = inputs_embeds.shape
628
+ else:
629
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
630
+
631
+ seq_length_with_past = seq_length
632
+ past_key_values_length = 0
633
+
634
+ if past_key_values is not None:
635
+ past_key_values_length = past_key_values[0][0].shape[2]
636
+ seq_length_with_past = seq_length_with_past + past_key_values_length
637
+
638
+ if position_ids is None:
639
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
640
+ position_ids = torch.arange(
641
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
642
+ )
643
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
644
+ else:
645
+ position_ids = position_ids.view(-1, seq_length).long()
646
+
647
+ if inputs_embeds is None:
648
+ inputs_embeds = self.embed_tokens(input_ids)
649
+ # embed positions
650
+ if attention_mask is None:
651
+ attention_mask = torch.ones(
652
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
653
+ )
654
+ attention_mask = self._prepare_decoder_attention_mask(
655
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
656
+ )
657
+
658
+ hidden_states = inputs_embeds
659
+
660
+ if self.gradient_checkpointing and self.training:
661
+ if use_cache:
662
+ logger.warning_once(
663
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
664
+ )
665
+ use_cache = False
666
+
667
+ # decoder layers
668
+ all_hidden_states = () if output_hidden_states else None
669
+ all_self_attns = () if output_attentions else None
670
+ next_decoder_cache = () if use_cache else None
671
+
672
+ for idx, decoder_layer in enumerate(self.layers):
673
+ if output_hidden_states:
674
+ all_hidden_states += (hidden_states,)
675
+
676
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
677
+
678
+ if self.gradient_checkpointing and self.training:
679
+
680
+ def create_custom_forward(module):
681
+ def custom_forward(*inputs):
682
+ # None for past_key_value
683
+ return module(*inputs, output_attentions, None)
684
+
685
+ return custom_forward
686
+
687
+ layer_outputs = torch.utils.checkpoint.checkpoint(
688
+ create_custom_forward(decoder_layer),
689
+ hidden_states,
690
+ attention_mask,
691
+ position_ids,
692
+ None,
693
+ )
694
+ else:
695
+ layer_outputs = decoder_layer(
696
+ hidden_states,
697
+ attention_mask=attention_mask,
698
+ position_ids=position_ids,
699
+ past_key_value=past_key_value,
700
+ output_attentions=output_attentions,
701
+ use_cache=use_cache,
702
+ )
703
+
704
+ hidden_states = layer_outputs[0]
705
+
706
+ if use_cache:
707
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
708
+
709
+ if output_attentions:
710
+ all_self_attns += (layer_outputs[1],)
711
+
712
+ hidden_states = self.norm(hidden_states)
713
+
714
+ # add hidden states from the last decoder layer
715
+ if output_hidden_states:
716
+ all_hidden_states += (hidden_states,)
717
+
718
+ next_cache = next_decoder_cache if use_cache else None
719
+ if not return_dict:
720
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
721
+ return BaseModelOutputWithPast(
722
+ last_hidden_state=hidden_states,
723
+ past_key_values=next_cache,
724
+ hidden_states=all_hidden_states,
725
+ attentions=all_self_attns,
726
+ )
727
+
728
+
729
+ class LlamaForCausalLM(LlamaPreTrainedModel):
730
+ _tied_weights_keys = ["lm_head.weight"]
731
+
732
+ def __init__(self, config):
733
+ super().__init__(config)
734
+ self.model = LlamaModel(config)
735
+ self.pretraining_tp = config.pretraining_tp
736
+ self.vocab_size = config.vocab_size
737
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
738
+
739
+ # Initialize weights and apply final processing
740
+ self.post_init()
741
+
742
+ def get_input_embeddings(self):
743
+ return self.model.embed_tokens
744
+
745
+ def set_input_embeddings(self, value):
746
+ self.model.embed_tokens = value
747
+
748
+ def get_output_embeddings(self):
749
+ return self.lm_head
750
+
751
+ def set_output_embeddings(self, new_embeddings):
752
+ self.lm_head = new_embeddings
753
+
754
+ def set_decoder(self, decoder):
755
+ self.model = decoder
756
+
757
+ def get_decoder(self):
758
+ return self.model
759
+
760
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
761
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
762
+ def forward(
763
+ self,
764
+ input_ids: torch.LongTensor = None,
765
+ attention_mask: Optional[torch.Tensor] = None,
766
+ position_ids: Optional[torch.LongTensor] = None,
767
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
768
+ inputs_embeds: Optional[torch.FloatTensor] = None,
769
+ labels: Optional[torch.LongTensor] = None,
770
+ use_cache: Optional[bool] = None,
771
+ output_attentions: Optional[bool] = None,
772
+ output_hidden_states: Optional[bool] = None,
773
+ return_dict: Optional[bool] = None,
774
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
775
+ r"""
776
+ Args:
777
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
778
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
779
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
780
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
781
+
782
+ Returns:
783
+
784
+ Example:
785
+
786
+ ```python
787
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
788
+
789
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
790
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
791
+
792
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
793
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
794
+
795
+ >>> # Generate
796
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
797
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
798
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
799
+ ```"""
800
+
801
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
802
+ output_hidden_states = (
803
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
804
+ )
805
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
806
+
807
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
808
+ outputs = self.model(
809
+ input_ids=input_ids,
810
+ attention_mask=attention_mask,
811
+ position_ids=position_ids,
812
+ past_key_values=past_key_values,
813
+ inputs_embeds=inputs_embeds,
814
+ use_cache=use_cache,
815
+ output_attentions=output_attentions,
816
+ output_hidden_states=output_hidden_states,
817
+ return_dict=return_dict,
818
+ )
819
+
820
+ hidden_states = outputs[0]
821
+ if self.pretraining_tp > 1:
822
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.pretraining_tp, dim=0)
823
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.pretraining_tp)]
824
+ logits = torch.cat(logits, dim=-1)
825
+ else:
826
+ logits = self.lm_head(hidden_states)
827
+ logits = logits.float()
828
+
829
+ loss = None
830
+ if labels is not None:
831
+ # Shift so that tokens < n predict n
832
+ shift_logits = logits[..., :-1, :].contiguous()
833
+ shift_labels = labels[..., 1:].contiguous()
834
+ # Flatten the tokens
835
+ loss_fct = CrossEntropyLoss()
836
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
837
+ shift_labels = shift_labels.view(-1)
838
+ # Enable model parallelism
839
+ shift_labels = shift_labels.to(shift_logits.device)
840
+ loss = loss_fct(shift_logits, shift_labels)
841
+
842
+ if not return_dict:
843
+ output = (logits,) + outputs[1:]
844
+ return (loss,) + output if loss is not None else output
845
+
846
+ return CausalLMOutputWithPast(
847
+ loss=loss,
848
+ logits=logits,
849
+ past_key_values=outputs.past_key_values,
850
+ hidden_states=outputs.hidden_states,
851
+ attentions=outputs.attentions,
852
+ )
853
+
854
+ def prepare_inputs_for_generation(
855
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
856
+ ):
857
+ if past_key_values:
858
+ input_ids = input_ids[:, -1:]
859
+
860
+ position_ids = kwargs.get("position_ids", None)
861
+ if attention_mask is not None and position_ids is None:
862
+ # create position_ids on the fly for batch generation
863
+ position_ids = attention_mask.long().cumsum(-1) - 1
864
+ position_ids.masked_fill_(attention_mask == 0, 1)
865
+ if past_key_values:
866
+ position_ids = position_ids[:, -1].unsqueeze(-1)
867
+
868
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
869
+ if inputs_embeds is not None and past_key_values is None:
870
+ model_inputs = {"inputs_embeds": inputs_embeds}
871
+ else:
872
+ model_inputs = {"input_ids": input_ids}
873
+
874
+ model_inputs.update(
875
+ {
876
+ "position_ids": position_ids,
877
+ "past_key_values": past_key_values,
878
+ "use_cache": kwargs.get("use_cache"),
879
+ "attention_mask": attention_mask,
880
+ }
881
+ )
882
+ return model_inputs
883
+
884
+ @staticmethod
885
+ def _reorder_cache(past_key_values, beam_idx):
886
+ reordered_past = ()
887
+ for layer_past in past_key_values:
888
+ reordered_past += (
889
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
890
+ )
891
+ return reordered_past
892
+
893
+
894
+ @add_start_docstrings(
895
+ """
896
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
897
+
898
+ [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
899
+ (e.g. GPT-2) do.
900
+
901
+ Since it does classification on the last token, it requires to know the position of the last token. If a
902
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
903
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
904
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
905
+ each row of the batch).
906
+ """,
907
+ LLAMA_START_DOCSTRING,
908
+ )
909
+ class LlamaForSequenceClassification(LlamaPreTrainedModel):
910
+ def __init__(self, config):
911
+ super().__init__(config)
912
+ self.num_labels = config.num_labels
913
+ self.model = LlamaModel(config)
914
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
915
+
916
+ # Initialize weights and apply final processing
917
+ self.post_init()
918
+
919
+ def get_input_embeddings(self):
920
+ return self.model.embed_tokens
921
+
922
+ def set_input_embeddings(self, value):
923
+ self.model.embed_tokens = value
924
+
925
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
926
+ def forward(
927
+ self,
928
+ input_ids: torch.LongTensor = None,
929
+ attention_mask: Optional[torch.Tensor] = None,
930
+ position_ids: Optional[torch.LongTensor] = None,
931
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
932
+ inputs_embeds: Optional[torch.FloatTensor] = None,
933
+ labels: Optional[torch.LongTensor] = None,
934
+ use_cache: Optional[bool] = None,
935
+ output_attentions: Optional[bool] = None,
936
+ output_hidden_states: Optional[bool] = None,
937
+ return_dict: Optional[bool] = None,
938
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
939
+ r"""
940
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
941
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
942
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
943
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
944
+ """
945
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
946
+
947
+ transformer_outputs = self.model(
948
+ input_ids,
949
+ attention_mask=attention_mask,
950
+ position_ids=position_ids,
951
+ past_key_values=past_key_values,
952
+ inputs_embeds=inputs_embeds,
953
+ use_cache=use_cache,
954
+ output_attentions=output_attentions,
955
+ output_hidden_states=output_hidden_states,
956
+ return_dict=return_dict,
957
+ )
958
+ hidden_states = transformer_outputs[0]
959
+ logits = self.score(hidden_states)
960
+
961
+ if input_ids is not None:
962
+ batch_size = input_ids.shape[0]
963
+ else:
964
+ batch_size = inputs_embeds.shape[0]
965
+
966
+ if self.config.pad_token_id is None and batch_size != 1:
967
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
968
+ if self.config.pad_token_id is None:
969
+ sequence_lengths = -1
970
+ else:
971
+ if input_ids is not None:
972
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
973
+ else:
974
+ sequence_lengths = -1
975
+
976
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
977
+
978
+ loss = None
979
+ if labels is not None:
980
+ labels = labels.to(logits.device)
981
+ if self.config.problem_type is None:
982
+ if self.num_labels == 1:
983
+ self.config.problem_type = "regression"
984
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
985
+ self.config.problem_type = "single_label_classification"
986
+ else:
987
+ self.config.problem_type = "multi_label_classification"
988
+
989
+ if self.config.problem_type == "regression":
990
+ loss_fct = MSELoss()
991
+ if self.num_labels == 1:
992
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
993
+ else:
994
+ loss = loss_fct(pooled_logits, labels)
995
+ elif self.config.problem_type == "single_label_classification":
996
+ loss_fct = CrossEntropyLoss()
997
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
998
+ elif self.config.problem_type == "multi_label_classification":
999
+ loss_fct = BCEWithLogitsLoss()
1000
+ loss = loss_fct(pooled_logits, labels)
1001
+ if not return_dict:
1002
+ output = (pooled_logits,) + transformer_outputs[1:]
1003
+ return ((loss,) + output) if loss is not None else output
1004
+
1005
+ return SequenceClassifierOutputWithPast(
1006
+ loss=loss,
1007
+ logits=pooled_logits,
1008
+ past_key_values=transformer_outputs.past_key_values,
1009
+ hidden_states=transformer_outputs.hidden_states,
1010
+ attentions=transformer_outputs.attentions,
1011
+ )
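
For reference, the last-non-padding pooling used by the sequence-classification head above reduces to the following standalone sketch (toy tensors only; a pad id of 0 is chosen purely for illustration):

import torch

logits = torch.randn(2, 7, 3)                          # [batch, seq_len, num_labels]
input_ids = torch.tensor([[5, 6, 7, 0, 0, 0, 0],
                          [5, 6, 7, 8, 9, 0, 0]])      # 0 plays the role of pad_token_id here
pad_token_id = 0

# index of the last non-padding token in each row, as in the forward pass above
sequence_lengths = (input_ids != pad_token_id).sum(-1) - 1            # tensor([2, 4])
# gather the classification logits at those positions
pooled_logits = logits[torch.arange(logits.size(0)), sequence_lengths]  # [batch, num_labels]
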
multimodal_encoder/visual.py ADDED
@@ -0,0 +1,452 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from https://github.com/microsoft/unilm/tree/master/beit
3
+ # --------------------------------------------------------
4
+
5
+ import os
6
+ from functools import partial
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ try:
14
+ from timm.models.layers import drop_path, to_2tuple
15
+ except ImportError:
16
+ from timm.layers import drop_path, to_2tuple
17
+
18
+ try:
19
+ import xformers.ops as xops
20
+ except ImportError:
21
+ xops = None
22
+ print("xformers is not installed; run 'pip install xformers' to enable memory-efficient attention")
23
+
24
+
25
+ class PatchDropout(nn.Module):
26
+ """
27
+ https://arxiv.org/abs/2212.00794
28
+ """
29
+
30
+ def __init__(self, prob, exclude_first_token=True):
31
+ super().__init__()
32
+ assert 0 <= prob < 1.
33
+ self.prob = prob
34
+ self.exclude_first_token = exclude_first_token # exclude CLS token
35
+ print(f"os.getenv('RoPE')={os.getenv('RoPE')}")
36
+
37
+ def forward(self, x):
38
+ if not self.training or self.prob == 0.:
39
+ return x
40
+
41
+ if self.exclude_first_token:
42
+ cls_tokens, x = x[:, :1], x[:, 1:]
43
+ else:
44
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
45
+
46
+ batch = x.size()[0]
47
+ num_tokens = x.size()[1]
48
+
49
+ batch_indices = torch.arange(batch)
50
+ batch_indices = batch_indices[..., None]
51
+
52
+ keep_prob = 1 - self.prob
53
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
54
+
55
+ rand = torch.randn(batch, num_tokens)
56
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
57
+
58
+ x = x[batch_indices, patch_indices_keep]
59
+
60
+ if self.exclude_first_token:
61
+ x = torch.cat((cls_tokens, x), dim=1)
62
+
63
+ if self.training and os.getenv('RoPE') == '1':
64
+ return x, patch_indices_keep
65
+
66
+ return x
67
+
68
+ class DropPath(nn.Module):
69
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
70
+ """
71
+ def __init__(self, drop_prob=None):
72
+ super(DropPath, self).__init__()
73
+ self.drop_prob = drop_prob
74
+
75
+ def forward(self, x):
76
+ return drop_path(x, self.drop_prob, self.training)
77
+
78
+ def extra_repr(self) -> str:
79
+ return 'p={}'.format(self.drop_prob)
80
+
81
+
82
+ class Mlp(nn.Module):
83
+ def __init__(
84
+ self,
85
+ in_features,
86
+ hidden_features=None,
87
+ out_features=None,
88
+ act_layer=nn.GELU,
89
+ norm_layer=nn.LayerNorm,
90
+ drop=0.,
91
+ subln=False,
92
+
93
+ ):
94
+ super().__init__()
95
+ out_features = out_features or in_features
96
+ hidden_features = hidden_features or in_features
97
+ self.fc1 = nn.Linear(in_features, hidden_features)
98
+ self.act = act_layer()
99
+
100
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
101
+
102
+ self.fc2 = nn.Linear(hidden_features, out_features)
103
+ self.drop = nn.Dropout(drop)
104
+
105
+ def forward(self, x):
106
+ x = self.fc1(x)
107
+ x = self.act(x)
108
+ # x = self.drop(x)
109
+ # dropout here is disabled; re-enable the line above to match the original BERT implementation
110
+ x = self.ffn_ln(x)
111
+
112
+ x = self.fc2(x)
113
+ x = self.drop(x)
114
+ return x
115
+
116
+ class SwiGLU(nn.Module):
117
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
118
+ norm_layer=nn.LayerNorm, subln=False):
119
+ super().__init__()
120
+ out_features = out_features or in_features
121
+ hidden_features = hidden_features or in_features
122
+
123
+ self.w1 = nn.Linear(in_features, hidden_features)
124
+ self.w2 = nn.Linear(in_features, hidden_features)
125
+
126
+ self.act = act_layer()
127
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
128
+ self.w3 = nn.Linear(hidden_features, out_features)
129
+
130
+ self.drop = nn.Dropout(drop)
131
+
132
+ def forward(self, x):
133
+ x1 = self.w1(x)
134
+ x2 = self.w2(x)
135
+ hidden = self.act(x1) * x2
136
+ x = self.ffn_ln(hidden)
137
+ x = self.w3(x)
138
+ x = self.drop(x)
139
+ return x
140
+
141
+ class Attention(nn.Module):
142
+ def __init__(
143
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
144
+ proj_drop=0., window_size=None, attn_head_dim=None, xattn=False, rope=None, subln=False, norm_layer=nn.LayerNorm):
145
+ super().__init__()
146
+ self.num_heads = num_heads
147
+ head_dim = dim // num_heads
148
+ if attn_head_dim is not None:
149
+ head_dim = attn_head_dim
150
+ all_head_dim = head_dim * self.num_heads
151
+ self.scale = qk_scale or head_dim ** -0.5
152
+
153
+ self.subln = subln
154
+ if self.subln:
155
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
156
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
157
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
158
+ else:
159
+ if qkv_bias:
160
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=True)
161
+ else:
162
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
163
+
164
+ # if qkv_bias:
165
+ # self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
166
+ # self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
167
+ # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
168
+ # self.qkv.bias.data = qkv_bias
169
+ # else:
170
+ # self.q_bias = None
171
+ # self.v_bias = None
172
+
173
+ self.window_size = None
174
+ self.relative_position_bias_table = None
175
+ self.relative_position_index = None
176
+
177
+ self.attn_drop = nn.Dropout(attn_drop)
178
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
179
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
180
+ self.proj = nn.Linear(all_head_dim, dim)
181
+ self.proj_drop = nn.Dropout(proj_drop)
182
+ self.xattn = xattn
183
+ self.xattn_drop = attn_drop
184
+
185
+ self.rope = rope
186
+
187
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
188
+ B, N, C = x.shape
189
+ if self.subln:
190
+ # note: q_bias/v_bias are never defined in this trimmed copy of the module,
+ # so the sub-LN projections are applied without bias
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=None)
191
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
192
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=None)
193
+
194
+ q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3) # B, num_heads, N, C
195
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
196
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
197
+ else:
198
+
199
+ # qkv_bias = None
200
+ # if self.q_bias is not None:
201
+ # qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
202
+
203
+ # qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
204
+
205
+ qkv = self.qkv(x)
206
+
207
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
208
+ q, k, v = qkv[0], qkv[1], qkv[2]
209
+
210
+ if self.rope:
211
+ q_t = q[:, :, 1:, :]
212
+ ro_q_t = self.rope(q_t)
213
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
214
+
215
+ k_t = k[:, :, 1:, :]
216
+ ro_k_t = self.rope(k_t)
217
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
218
+
219
+ if self.xattn:
220
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
221
+ k = k.permute(0, 2, 1, 3)
222
+ v = v.permute(0, 2, 1, 3)
223
+
224
+ x = xops.memory_efficient_attention(
225
+ q, k, v,
226
+ p=self.xattn_drop,
227
+ scale=self.scale,
228
+ )
229
+ x = x.reshape(B, N, -1)
230
+ x = self.inner_attn_ln(x)
231
+ x = self.proj(x)
232
+ x = self.proj_drop(x)
233
+ else:
234
+ q = q * self.scale
235
+ attn = (q @ k.transpose(-2, -1))
236
+
237
+ if self.relative_position_bias_table is not None:
238
+ relative_position_bias = \
239
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
240
+ self.window_size[0] * self.window_size[1] + 1,
241
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
242
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
243
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
244
+
245
+ if rel_pos_bias is not None:
246
+ attn = attn + rel_pos_bias.type_as(attn)
247
+
248
+ if attn_mask is not None:
249
+ attn_mask = attn_mask.bool()
250
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
251
+
252
+ attn = attn.softmax(dim=-1)
253
+ attn = self.attn_drop(attn)
254
+
255
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
256
+ x = self.inner_attn_ln(x)
257
+ x = self.proj(x)
258
+ x = self.proj_drop(x)
259
+ return x
260
+
261
+
262
+ class Block(nn.Module):
263
+
264
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
265
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
266
+ window_size=None, attn_head_dim=None, xattn=False, rope=None, postnorm=False,
267
+ subln=False, naiveswiglu=False):
268
+ super().__init__()
269
+ self.norm1 = norm_layer(dim)
270
+ self.attn = Attention(
271
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
272
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim,
273
+ xattn=xattn, rope=rope, subln=subln, norm_layer=norm_layer)
274
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
275
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
276
+ self.norm2 = norm_layer(dim)
277
+ mlp_hidden_dim = int(dim * mlp_ratio)
278
+
279
+ if naiveswiglu:
280
+ self.mlp = SwiGLU(
281
+ in_features=dim,
282
+ hidden_features=mlp_hidden_dim,
283
+ subln=subln,
284
+ norm_layer=norm_layer,
285
+ )
286
+ else:
287
+ self.mlp = Mlp(
288
+ in_features=dim,
289
+ hidden_features=mlp_hidden_dim,
290
+ act_layer=act_layer,
291
+ subln=subln,
292
+ drop=drop
293
+ )
294
+
295
+ if init_values is not None and init_values > 0:
296
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
297
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
298
+ else:
299
+ self.gamma_1, self.gamma_2 = None, None
300
+
301
+ self.postnorm = postnorm
302
+
303
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
304
+ if self.gamma_1 is None:
305
+ if self.postnorm:
306
+ x = x + self.drop_path(self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
307
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
308
+ else:
309
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
310
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
311
+ else:
312
+ if self.postnorm:
313
+ x = x + self.drop_path(self.gamma_1 * self.norm1(self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)))
314
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
315
+ else:
316
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask))
317
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
318
+ return x
319
+
320
+
321
+ class PatchEmbed(nn.Module):
322
+ """ Image to Patch Embedding
323
+ """
324
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
325
+ super().__init__()
326
+ img_size = to_2tuple(img_size)
327
+ patch_size = to_2tuple(patch_size)
328
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
329
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
330
+ self.img_size = img_size
331
+ self.patch_size = patch_size
332
+ self.num_patches = num_patches
333
+
334
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
335
+
336
+ def forward(self, x, **kwargs):
337
+ B, C, H, W = x.shape
338
+ # FIXME look at relaxing size constraints
339
+ assert H == self.img_size[0] and W == self.img_size[1], \
340
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
341
+ x = self.proj(x).flatten(2).transpose(1, 2)
342
+ return x
343
+
344
+
345
+ class EVAVisionTransformer(nn.Module):
346
+ """ Vision Transformer with support for patch or hybrid CNN input stage
347
+ """
348
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
349
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
350
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, patch_dropout=0.,
351
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, rope=False,
352
+ use_mean_pooling=True, init_scale=0.001, grad_checkpointing=False, xattn=False, postnorm=False,
353
+ pt_hw_seq_len=16, intp_freq=False, naiveswiglu=False, subln=False,
354
+ ):
355
+ super().__init__()
356
+ self.image_size = img_size
357
+ # self.num_classes = num_classes
358
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
359
+
360
+ self.patch_embed = PatchEmbed(
361
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
362
+ num_patches = self.patch_embed.num_patches
363
+
364
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
365
+ if use_abs_pos_emb:
366
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
367
+ else:
368
+ self.pos_embed = None
369
+ self.pos_drop = nn.Dropout(p=drop_rate)
370
+
371
+ self.rel_pos_bias = None
372
+ self.rope = None
373
+
374
+ self.naiveswiglu = naiveswiglu
375
+
376
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
377
+ self.use_rel_pos_bias = use_rel_pos_bias
378
+ self.blocks = nn.ModuleList([
379
+ Block(
380
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
381
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
382
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
383
+ xattn=xattn, rope=self.rope, postnorm=postnorm, subln=subln, naiveswiglu=naiveswiglu)
384
+ for i in range(depth)])
385
+
386
+ # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
387
+ self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()
388
+
389
+ self.grad_checkpointing = grad_checkpointing
390
+
391
+
392
+ def get_num_layers(self):
393
+ return len(self.blocks)
394
+
395
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
396
+ assert unlocked_groups == 0, 'partial locking not currently supported for this model'
397
+ for param in self.parameters():
398
+ param.requires_grad = False
399
+
400
+ @torch.jit.ignore
401
+ def set_grad_checkpointing(self, enable=True):
402
+ self.grad_checkpointing = enable
403
+
404
+ @torch.jit.ignore
405
+ def no_weight_decay(self):
406
+ return {'pos_embed', 'cls_token'}
407
+
408
+
409
+ def forward_features(self, x):
410
+ x = self.patch_embed(x)
411
+ batch_size, seq_len, _ = x.size()
412
+
413
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
414
+ x = torch.cat((cls_tokens, x), dim=1)
415
+ if self.pos_embed is not None:
416
+ x = x + self.pos_embed
417
+ x = self.pos_drop(x)
418
+
419
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
420
+ if os.getenv('RoPE') == '1':
421
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
422
+ x, patch_indices_keep = self.patch_dropout(x)
423
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=patch_indices_keep)
424
+ else:
425
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
426
+ x = self.patch_dropout(x)
427
+ else:
428
+ x = self.patch_dropout(x)
429
+
430
+ rel_pos_bias = None
431
+
432
+ for blk in self.blocks:
433
+ if self.grad_checkpointing:
434
+ x = checkpoint(blk, x, (rel_pos_bias,))
435
+ else:
436
+ x = blk(x, rel_pos_bias=rel_pos_bias)
437
+
438
+ return x
439
+
440
+ def forward(self, x):
441
+
442
+ """
443
+ :param x: input image tensor of shape [B, C, H, W]
444
+ :return: raw ViT features of shape [B, n_patch + 1, C] produced by
445
+ forward_features; no normalization or pooling head is applied,
446
+ so downstream code is expected to select or project the
447
+ patch/CLS tokens it needs
448
+ """
449
+
450
+ features = self.forward_features(x) # [B, n_patch, C]
451
+
452
+ return features
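
A quick smoke test for the EVAVisionTransformer defined above, assuming the class is importable from this module and timm is installed. The hyper-parameters below are deliberately tiny placeholders so the sketch runs on CPU without xformers; they are not the released configuration:

from functools import partial

import torch
import torch.nn as nn

vit = EVAVisionTransformer(
    img_size=224, patch_size=14, embed_dim=192, depth=2, num_heads=4,
    mlp_ratio=4.0, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6),
    xattn=False, postnorm=True,
)
features = vit(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 257, 192]): 1 CLS token + (224 // 14) ** 2 patches
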
pipeline_emu2_gen.py ADDED
@@ -0,0 +1,234 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # ===========================================================================================
4
+ #
5
+ # Copyright (c) Beijing Academy of Artificial Intelligence (BAAI). All rights reserved.
6
+ #
7
+ # Author : Fan Zhang
8
+ # Email : zhangfan@baai.ac.cn
9
+ # Institute : Beijing Academy of Artificial Intelligence (BAAI)
10
+ # Create On : 2023-12-19 10:45
11
+ # Last Modified : 2023-12-19 14:01
12
+ # File Name : pipeline_emu2_gen.py
13
+ # Description :
14
+ #
15
+ # ===========================================================================================
16
+
17
+ from dataclasses import dataclass
18
+ from typing import List, Optional, Union
19
+
20
+ from PIL import Image
21
+ import numpy as np
22
+ import torch
23
+ from torchvision import transforms as TF
24
+ from tqdm import tqdm
25
+
26
+ from diffusers import DiffusionPipeline
27
+ from diffusers.utils import BaseOutput
28
+
29
+ from diffusers import UNet2DConditionModel, EulerDiscreteScheduler, AutoencoderKL
30
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
31
+ from transformers import CLIPImageProcessor
32
+ from transformers import AutoModelForCausalLM, AutoTokenizer
33
+
34
+ EVA_IMAGE_SIZE = 448
35
+ OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
36
+ OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
37
+ DEFAULT_IMG_PLACEHOLDER = "[<IMG_PLH>]"
38
+
39
+ @dataclass
40
+ class EmuVisualGenerationPipelineOutput(BaseOutput):
41
+ images: Union[List[Image.Image], np.ndarray]
42
+ nsfw_content_detected: Optional[List[bool]]
43
+
44
+
45
+ class EmuVisualGenerationPipeline(DiffusionPipeline):
46
+
47
+ def __init__(
48
+ self,
49
+ tokenizer: AutoTokenizer,
50
+ multimodal_encoder: AutoModelForCausalLM,
51
+ scheduler: EulerDiscreteScheduler,
52
+ unet: UNet2DConditionModel,
53
+ vae: AutoencoderKL,
54
+ feature_extractor: CLIPImageProcessor,
55
+ safety_checker: StableDiffusionSafetyChecker,
56
+ eva_size=EVA_IMAGE_SIZE,
57
+ eva_mean=OPENAI_DATASET_MEAN,
58
+ eva_std=OPENAI_DATASET_STD,
59
+ ):
60
+ super().__init__()
61
+ self.register_modules(
62
+ tokenizer=tokenizer,
63
+ multimodal_encoder=multimodal_encoder,
64
+ scheduler=scheduler,
65
+ unet=unet,
66
+ vae=vae,
67
+ feature_extractor=feature_extractor,
68
+ safety_checker=safety_checker,
69
+ )
70
+
71
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
72
+
73
+ self.transform = TF.Compose([
74
+ TF.Resize((eva_size, eva_size), interpolation=TF.InterpolationMode.BICUBIC),
75
+ TF.ToTensor(),
76
+ TF.Normalize(mean=eva_mean, std=eva_std),
77
+ ])
78
+
79
+ self.negative_prompt = None
80
+
81
+ def device(self, module):
82
+ return next(module.parameters()).device
83
+
84
+ def dtype(self, module):
85
+ return next(module.parameters()).dtype
86
+
87
+ @torch.no_grad()
88
+ def __call__(
89
+ self,
90
+ inputs: List[Image.Image | str] | str | Image.Image,
91
+ height: int = 1024,
92
+ width: int = 1024,
93
+ num_inference_steps: int = 50,
94
+ guidance_scale: float = 3.,
95
+ crop_info: List[int] = [0, 0],
96
+ original_size: List[int] = [1024, 1024],
97
+ ):
98
+ if not isinstance(inputs, list):
99
+ inputs = [inputs]
100
+
101
+ # 0. Default height and width to unet
102
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
103
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
104
+
105
+ device = self.device(self.unet)
106
+ dtype = self.dtype(self.unet)
107
+
108
+ do_classifier_free_guidance = guidance_scale > 1.0
109
+
110
+ # 1. Encode input prompt
111
+ prompt_embeds = self._prepare_and_encode_inputs(
112
+ inputs,
113
+ do_classifier_free_guidance,
114
+ ).to(dtype).to(device)
115
+ batch_size = prompt_embeds.shape[0] // 2 if do_classifier_free_guidance else prompt_embeds.shape[0]
116
+
117
+ unet_added_conditions = {}
118
+ time_ids = torch.LongTensor(original_size + crop_info + [height, width]).to(device)
119
+ if do_classifier_free_guidance:
120
+ unet_added_conditions["time_ids"] = torch.cat([time_ids, time_ids], dim=0)
121
+ else:
122
+ unet_added_conditions["time_ids"] = time_ids
123
+ unet_added_conditions["text_embeds"] = torch.mean(prompt_embeds, dim=1)
124
+
125
+ # 2. Prepare timesteps
126
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
127
+ timesteps = self.scheduler.timesteps
128
+
129
+ # 3. Prepare latent variables
130
+ shape = (
131
+ batch_size,
132
+ self.unet.config.in_channels,
133
+ height // self.vae_scale_factor,
134
+ width // self.vae_scale_factor,
135
+ )
136
+ latents = torch.randn(shape, device=device, dtype=dtype)
137
+ latents = latents * self.scheduler.init_noise_sigma
138
+
139
+ # 4. Denoising loop
140
+ for t in tqdm(timesteps):
141
+ # expand the latents if we are doing classifier free guidance
142
+ # 2B x 4 x H x W
143
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
144
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
145
+
146
+ noise_pred = self.unet(
147
+ latent_model_input,
148
+ t,
149
+ encoder_hidden_states=prompt_embeds,
150
+ added_cond_kwargs=unet_added_conditions,
151
+ ).sample
152
+
153
+ # perform guidance
154
+ if do_classifier_free_guidance:
155
+ noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
156
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
157
+
158
+ # compute the previous noisy sample x_t -> x_t-1
159
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
160
+
161
+ # 5. Post-processing
162
+ images = self.decode_latents(latents)
163
+
164
+ # 6. Run safety checker
165
+ images, has_nsfw_concept = self.run_safety_checker(images)
166
+
167
+ # 7. Convert to PIL
168
+ images = self.numpy_to_pil(images)
169
+ return EmuVisualGenerationPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
170
+
171
+ def _prepare_and_encode_inputs(
172
+ self,
173
+ inputs: List[str | Image.Image],
174
+ do_classifier_free_guidance: bool = False,
175
+ placeholder: str = DEFAULT_IMG_PLACEHOLDER,
176
+ ):
177
+ device = self.device(self.multimodal_encoder.model.visual)
178
+ dtype = self.dtype(self.multimodal_encoder.model.visual)
179
+
180
+ text_prompt, image_prompt = "", []
181
+ for x in inputs:
182
+ if isinstance(x, str):
183
+ text_prompt += x
184
+ else:
185
+ text_prompt += placeholder
186
+ image_prompt.append(self.transform(x))
187
+
188
+ if len(image_prompt) == 0:
189
+ image_prompt = None
190
+ else:
191
+ image_prompt = torch.stack(image_prompt)
192
+ image_prompt = image_prompt.type(dtype).to(device)
193
+
194
+ prompt = self.multimodal_encoder.generate_image(text=[text_prompt], image=image_prompt, tokenizer=self.tokenizer)
195
+ if do_classifier_free_guidance:
196
+ if self.negative_prompt is None:
197
+ self.negative_prompt = self.multimodal_encoder.generate_image(text=[""], tokenizer=self.tokenizer)
198
+ prompt = torch.cat([prompt, self.negative_prompt], dim=0)
199
+
200
+ return prompt
201
+
202
+ def decode_latents(self, latents: torch.Tensor) -> np.ndarray:
203
+ latents = 1 / self.vae.config.scaling_factor * latents
204
+ image = self.vae.decode(latents).sample
205
+ image = (image / 2 + 0.5).clamp(0, 1)
206
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
207
+ return image
208
+
209
+ def numpy_to_pil(self, images: np.ndarray) -> List[Image.Image]:
210
+ """
211
+ Convert a numpy image or a batch of images to a PIL image.
212
+ """
213
+ if images.ndim == 3:
214
+ images = images[None, ...]
215
+ images = (images * 255).round().astype("uint8")
216
+ if images.shape[-1] == 1:
217
+ # special case for grayscale (single channel) images
218
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
219
+ else:
220
+ pil_images = [Image.fromarray(image) for image in images]
221
+
222
+ return pil_images
223
+
224
+ def run_safety_checker(self, images: np.ndarray):
225
+ if self.safety_checker is not None:
226
+ device = self.device(self.safety_checker)
227
+ dtype = self.dtype(self.safety_checker)
228
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(images), return_tensors="pt").to(device)
229
+ images, has_nsfw_concept = self.safety_checker(
230
+ images=images, clip_input=safety_checker_input.pixel_values.to(dtype)
231
+ )
232
+ else:
233
+ has_nsfw_concept = None
234
+ return images, has_nsfw_concept
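
For context, a minimal usage sketch of the pipeline class defined above. The local path is hypothetical, and resolving the custom pipeline through diffusers' custom_pipeline argument together with the bf16 variant flag are assumptions based on the files in this commit, not an official recipe:

import torch
from PIL import Image
from diffusers import DiffusionPipeline

local_path = "path/to/Emu2-Gen"           # hypothetical local checkout of this repository
pipe = DiffusionPipeline.from_pretrained(
    local_path,
    custom_pipeline="pipeline_emu2_gen",  # the pipeline file added in this commit
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="bf16",                       # matches the *.bf16.safetensors weights
).to("cuda")

# interleaved text/image prompt, mirroring _prepare_and_encode_inputs above
output = pipe(
    inputs=["impressionist painting of", Image.open("dog.png").convert("RGB")],
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=3.0,
)
output.images[0].save("result.png")
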
safety_checker/config.json ADDED
@@ -0,0 +1,168 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "/share/project/quansun/release_hf/Emu2-VisualGeneration/safety_checker",
4
+ "architectures": [
5
+ "StableDiffusionSafetyChecker"
6
+ ],
7
+ "initializer_factor": 1.0,
8
+ "logit_scale_init_value": 2.6592,
9
+ "model_type": "clip",
10
+ "projection_dim": 768,
11
+ "text_config": {
12
+ "_name_or_path": "",
13
+ "add_cross_attention": false,
14
+ "architectures": null,
15
+ "attention_dropout": 0.0,
16
+ "bad_words_ids": null,
17
+ "begin_suppress_tokens": null,
18
+ "bos_token_id": 49406,
19
+ "chunk_size_feed_forward": 0,
20
+ "cross_attention_hidden_size": null,
21
+ "decoder_start_token_id": null,
22
+ "diversity_penalty": 0.0,
23
+ "do_sample": false,
24
+ "dropout": 0.0,
25
+ "early_stopping": false,
26
+ "encoder_no_repeat_ngram_size": 0,
27
+ "eos_token_id": 49407,
28
+ "exponential_decay_length_penalty": null,
29
+ "finetuning_task": null,
30
+ "forced_bos_token_id": null,
31
+ "forced_eos_token_id": null,
32
+ "hidden_act": "quick_gelu",
33
+ "hidden_size": 768,
34
+ "id2label": {
35
+ "0": "LABEL_0",
36
+ "1": "LABEL_1"
37
+ },
38
+ "initializer_factor": 1.0,
39
+ "initializer_range": 0.02,
40
+ "intermediate_size": 3072,
41
+ "is_decoder": false,
42
+ "is_encoder_decoder": false,
43
+ "label2id": {
44
+ "LABEL_0": 0,
45
+ "LABEL_1": 1
46
+ },
47
+ "layer_norm_eps": 1e-05,
48
+ "length_penalty": 1.0,
49
+ "max_length": 20,
50
+ "max_position_embeddings": 77,
51
+ "min_length": 0,
52
+ "model_type": "clip_text_model",
53
+ "no_repeat_ngram_size": 0,
54
+ "num_attention_heads": 12,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_hidden_layers": 12,
58
+ "num_return_sequences": 1,
59
+ "output_attentions": false,
60
+ "output_hidden_states": false,
61
+ "output_scores": false,
62
+ "pad_token_id": 1,
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "projection_dim": 512,
66
+ "pruned_heads": {},
67
+ "remove_invalid_values": false,
68
+ "repetition_penalty": 1.0,
69
+ "return_dict": true,
70
+ "return_dict_in_generate": false,
71
+ "sep_token_id": null,
72
+ "suppress_tokens": null,
73
+ "task_specific_params": null,
74
+ "temperature": 1.0,
75
+ "tf_legacy_loss": false,
76
+ "tie_encoder_decoder": false,
77
+ "tie_word_embeddings": true,
78
+ "tokenizer_class": null,
79
+ "top_k": 50,
80
+ "top_p": 1.0,
81
+ "torch_dtype": null,
82
+ "torchscript": false,
83
+ "transformers_version": "4.31.0",
84
+ "typical_p": 1.0,
85
+ "use_bfloat16": false,
86
+ "vocab_size": 49408
87
+ },
88
+ "torch_dtype": "bfloat16",
89
+ "transformers_version": null,
90
+ "vision_config": {
91
+ "_name_or_path": "",
92
+ "add_cross_attention": false,
93
+ "architectures": null,
94
+ "attention_dropout": 0.0,
95
+ "bad_words_ids": null,
96
+ "begin_suppress_tokens": null,
97
+ "bos_token_id": null,
98
+ "chunk_size_feed_forward": 0,
99
+ "cross_attention_hidden_size": null,
100
+ "decoder_start_token_id": null,
101
+ "diversity_penalty": 0.0,
102
+ "do_sample": false,
103
+ "dropout": 0.0,
104
+ "early_stopping": false,
105
+ "encoder_no_repeat_ngram_size": 0,
106
+ "eos_token_id": null,
107
+ "exponential_decay_length_penalty": null,
108
+ "finetuning_task": null,
109
+ "forced_bos_token_id": null,
110
+ "forced_eos_token_id": null,
111
+ "hidden_act": "quick_gelu",
112
+ "hidden_size": 1024,
113
+ "id2label": {
114
+ "0": "LABEL_0",
115
+ "1": "LABEL_1"
116
+ },
117
+ "image_size": 224,
118
+ "initializer_factor": 1.0,
119
+ "initializer_range": 0.02,
120
+ "intermediate_size": 4096,
121
+ "is_decoder": false,
122
+ "is_encoder_decoder": false,
123
+ "label2id": {
124
+ "LABEL_0": 0,
125
+ "LABEL_1": 1
126
+ },
127
+ "layer_norm_eps": 1e-05,
128
+ "length_penalty": 1.0,
129
+ "max_length": 20,
130
+ "min_length": 0,
131
+ "model_type": "clip_vision_model",
132
+ "no_repeat_ngram_size": 0,
133
+ "num_attention_heads": 16,
134
+ "num_beam_groups": 1,
135
+ "num_beams": 1,
136
+ "num_channels": 3,
137
+ "num_hidden_layers": 24,
138
+ "num_return_sequences": 1,
139
+ "output_attentions": false,
140
+ "output_hidden_states": false,
141
+ "output_scores": false,
142
+ "pad_token_id": null,
143
+ "patch_size": 14,
144
+ "prefix": null,
145
+ "problem_type": null,
146
+ "projection_dim": 512,
147
+ "pruned_heads": {},
148
+ "remove_invalid_values": false,
149
+ "repetition_penalty": 1.0,
150
+ "return_dict": true,
151
+ "return_dict_in_generate": false,
152
+ "sep_token_id": null,
153
+ "suppress_tokens": null,
154
+ "task_specific_params": null,
155
+ "temperature": 1.0,
156
+ "tf_legacy_loss": false,
157
+ "tie_encoder_decoder": false,
158
+ "tie_word_embeddings": true,
159
+ "tokenizer_class": null,
160
+ "top_k": 50,
161
+ "top_p": 1.0,
162
+ "torch_dtype": null,
163
+ "torchscript": false,
164
+ "transformers_version": "4.31.0",
165
+ "typical_p": 1.0,
166
+ "use_bfloat16": false
167
+ }
168
+ }
safety_checker/model.bf16.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:013ddb2eb3e3ddb6b91fd739de8abbc8281de91f2ae9f5067ac8586d6aa29cf6
3
+ size 608016672
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "_class_name": "EulerDiscreteScheduler",
3
+ "_diffusers_version": "0.21.2",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "interpolation_type": "linear",
9
+ "num_train_timesteps": 1000,
10
+ "prediction_type": "epsilon",
11
+ "sample_max_value": 1.0,
12
+ "set_alpha_to_one": false,
13
+ "skip_prk_steps": true,
14
+ "steps_offset": 1,
15
+ "timestep_spacing": "leading",
16
+ "trained_betas": null,
17
+ "use_karras_sigmas": false
18
+ }
tokenizer/added_tokens.json ADDED
@@ -0,0 +1,274 @@
1
+ {
2
+ "</delimiter_of_multi_objects/>": 32013,
3
+ "</object>": 32012,
4
+ "</phrase>": 32010,
5
+ "<REC>": 32014,
6
+ "<grounding>": 32008,
7
+ "<image>": 32003,
8
+ "<object>": 32011,
9
+ "<patch_index_0000>": 32015,
10
+ "<patch_index_0001>": 32016,
11
+ "<patch_index_0002>": 32017,
12
+ "<patch_index_0003>": 32018,
13
+ "<patch_index_0004>": 32019,
14
+ "<patch_index_0005>": 32020,
15
+ "<patch_index_0006>": 32021,
16
+ "<patch_index_0007>": 32022,
17
+ "<patch_index_0008>": 32023,
18
+ "<patch_index_0009>": 32024,
19
+ "<patch_index_0010>": 32025,
20
+ "<patch_index_0011>": 32026,
21
+ "<patch_index_0012>": 32027,
22
+ "<patch_index_0013>": 32028,
23
+ "<patch_index_0014>": 32029,
24
+ "<patch_index_0015>": 32030,
25
+ "<patch_index_0016>": 32031,
26
+ "<patch_index_0017>": 32032,
27
+ "<patch_index_0018>": 32033,
28
+ "<patch_index_0019>": 32034,
29
+ "<patch_index_0020>": 32035,
30
+ "<patch_index_0021>": 32036,
31
+ "<patch_index_0022>": 32037,
32
+ "<patch_index_0023>": 32038,
33
+ "<patch_index_0024>": 32039,
34
+ "<patch_index_0025>": 32040,
35
+ "<patch_index_0026>": 32041,
36
+ "<patch_index_0027>": 32042,
37
+ "<patch_index_0028>": 32043,
38
+ "<patch_index_0029>": 32044,
39
+ "<patch_index_0030>": 32045,
40
+ "<patch_index_0031>": 32046,
41
+ "<patch_index_0032>": 32047,
42
+ "<patch_index_0033>": 32048,
43
+ "<patch_index_0034>": 32049,
44
+ "<patch_index_0035>": 32050,
45
+ "<patch_index_0036>": 32051,
46
+ "<patch_index_0037>": 32052,
47
+ "<patch_index_0038>": 32053,
48
+ "<patch_index_0039>": 32054,
49
+ "<patch_index_0040>": 32055,
50
+ "<patch_index_0041>": 32056,
51
+ "<patch_index_0042>": 32057,
52
+ "<patch_index_0043>": 32058,
53
+ "<patch_index_0044>": 32059,
54
+ "<patch_index_0045>": 32060,
55
+ "<patch_index_0046>": 32061,
56
+ "<patch_index_0047>": 32062,
57
+ "<patch_index_0048>": 32063,
58
+ "<patch_index_0049>": 32064,
59
+ "<patch_index_0050>": 32065,
60
+ "<patch_index_0051>": 32066,
61
+ "<patch_index_0052>": 32067,
62
+ "<patch_index_0053>": 32068,
63
+ "<patch_index_0054>": 32069,
64
+ "<patch_index_0055>": 32070,
65
+ "<patch_index_0056>": 32071,
66
+ "<patch_index_0057>": 32072,
67
+ "<patch_index_0058>": 32073,
68
+ "<patch_index_0059>": 32074,
69
+ "<patch_index_0060>": 32075,
70
+ "<patch_index_0061>": 32076,
71
+ "<patch_index_0062>": 32077,
72
+ "<patch_index_0063>": 32078,
73
+ "<patch_index_0064>": 32079,
74
+ "<patch_index_0065>": 32080,
75
+ "<patch_index_0066>": 32081,
76
+ "<patch_index_0067>": 32082,
77
+ "<patch_index_0068>": 32083,
78
+ "<patch_index_0069>": 32084,
79
+ "<patch_index_0070>": 32085,
80
+ "<patch_index_0071>": 32086,
81
+ "<patch_index_0072>": 32087,
82
+ "<patch_index_0073>": 32088,
83
+ "<patch_index_0074>": 32089,
84
+ "<patch_index_0075>": 32090,
85
+ "<patch_index_0076>": 32091,
86
+ "<patch_index_0077>": 32092,
87
+ "<patch_index_0078>": 32093,
88
+ "<patch_index_0079>": 32094,
89
+ "<patch_index_0080>": 32095,
90
+ "<patch_index_0081>": 32096,
91
+ "<patch_index_0082>": 32097,
92
+ "<patch_index_0083>": 32098,
93
+ "<patch_index_0084>": 32099,
94
+ "<patch_index_0085>": 32100,
95
+ "<patch_index_0086>": 32101,
96
+ "<patch_index_0087>": 32102,
97
+ "<patch_index_0088>": 32103,
98
+ "<patch_index_0089>": 32104,
99
+ "<patch_index_0090>": 32105,
100
+ "<patch_index_0091>": 32106,
101
+ "<patch_index_0092>": 32107,
102
+ "<patch_index_0093>": 32108,
103
+ "<patch_index_0094>": 32109,
104
+ "<patch_index_0095>": 32110,
105
+ "<patch_index_0096>": 32111,
106
+ "<patch_index_0097>": 32112,
107
+ "<patch_index_0098>": 32113,
108
+ "<patch_index_0099>": 32114,
109
+ "<patch_index_0100>": 32115,
110
+ "<patch_index_0101>": 32116,
111
+ "<patch_index_0102>": 32117,
112
+ "<patch_index_0103>": 32118,
113
+ "<patch_index_0104>": 32119,
114
+ "<patch_index_0105>": 32120,
115
+ "<patch_index_0106>": 32121,
116
+ "<patch_index_0107>": 32122,
117
+ "<patch_index_0108>": 32123,
118
+ "<patch_index_0109>": 32124,
119
+ "<patch_index_0110>": 32125,
120
+ "<patch_index_0111>": 32126,
121
+ "<patch_index_0112>": 32127,
122
+ "<patch_index_0113>": 32128,
123
+ "<patch_index_0114>": 32129,
124
+ "<patch_index_0115>": 32130,
125
+ "<patch_index_0116>": 32131,
126
+ "<patch_index_0117>": 32132,
127
+ "<patch_index_0118>": 32133,
128
+ "<patch_index_0119>": 32134,
129
+ "<patch_index_0120>": 32135,
130
+ "<patch_index_0121>": 32136,
131
+ "<patch_index_0122>": 32137,
132
+ "<patch_index_0123>": 32138,
133
+ "<patch_index_0124>": 32139,
134
+ "<patch_index_0125>": 32140,
135
+ "<patch_index_0126>": 32141,
136
+ "<patch_index_0127>": 32142,
137
+ "<patch_index_0128>": 32143,
138
+ "<patch_index_0129>": 32144,
139
+ "<patch_index_0130>": 32145,
140
+ "<patch_index_0131>": 32146,
141
+ "<patch_index_0132>": 32147,
142
+ "<patch_index_0133>": 32148,
143
+ "<patch_index_0134>": 32149,
144
+ "<patch_index_0135>": 32150,
145
+ "<patch_index_0136>": 32151,
146
+ "<patch_index_0137>": 32152,
147
+ "<patch_index_0138>": 32153,
148
+ "<patch_index_0139>": 32154,
149
+ "<patch_index_0140>": 32155,
150
+ "<patch_index_0141>": 32156,
151
+ "<patch_index_0142>": 32157,
152
+ "<patch_index_0143>": 32158,
153
+ "<patch_index_0144>": 32159,
154
+ "<patch_index_0145>": 32160,
155
+ "<patch_index_0146>": 32161,
156
+ "<patch_index_0147>": 32162,
157
+ "<patch_index_0148>": 32163,
158
+ "<patch_index_0149>": 32164,
159
+ "<patch_index_0150>": 32165,
160
+ "<patch_index_0151>": 32166,
161
+ "<patch_index_0152>": 32167,
162
+ "<patch_index_0153>": 32168,
163
+ "<patch_index_0154>": 32169,
164
+ "<patch_index_0155>": 32170,
165
+ "<patch_index_0156>": 32171,
166
+ "<patch_index_0157>": 32172,
167
+ "<patch_index_0158>": 32173,
168
+ "<patch_index_0159>": 32174,
169
+ "<patch_index_0160>": 32175,
170
+ "<patch_index_0161>": 32176,
171
+ "<patch_index_0162>": 32177,
172
+ "<patch_index_0163>": 32178,
173
+ "<patch_index_0164>": 32179,
174
+ "<patch_index_0165>": 32180,
175
+ "<patch_index_0166>": 32181,
176
+ "<patch_index_0167>": 32182,
177
+ "<patch_index_0168>": 32183,
178
+ "<patch_index_0169>": 32184,
179
+ "<patch_index_0170>": 32185,
180
+ "<patch_index_0171>": 32186,
181
+ "<patch_index_0172>": 32187,
182
+ "<patch_index_0173>": 32188,
183
+ "<patch_index_0174>": 32189,
184
+ "<patch_index_0175>": 32190,
185
+ "<patch_index_0176>": 32191,
186
+ "<patch_index_0177>": 32192,
187
+ "<patch_index_0178>": 32193,
188
+ "<patch_index_0179>": 32194,
189
+ "<patch_index_0180>": 32195,
190
+ "<patch_index_0181>": 32196,
191
+ "<patch_index_0182>": 32197,
192
+ "<patch_index_0183>": 32198,
193
+ "<patch_index_0184>": 32199,
194
+ "<patch_index_0185>": 32200,
195
+ "<patch_index_0186>": 32201,
196
+ "<patch_index_0187>": 32202,
197
+ "<patch_index_0188>": 32203,
198
+ "<patch_index_0189>": 32204,
199
+ "<patch_index_0190>": 32205,
200
+ "<patch_index_0191>": 32206,
201
+ "<patch_index_0192>": 32207,
202
+ "<patch_index_0193>": 32208,
203
+ "<patch_index_0194>": 32209,
204
+ "<patch_index_0195>": 32210,
205
+ "<patch_index_0196>": 32211,
206
+ "<patch_index_0197>": 32212,
207
+ "<patch_index_0198>": 32213,
208
+ "<patch_index_0199>": 32214,
209
+ "<patch_index_0200>": 32215,
210
+ "<patch_index_0201>": 32216,
211
+ "<patch_index_0202>": 32217,
212
+ "<patch_index_0203>": 32218,
213
+ "<patch_index_0204>": 32219,
214
+ "<patch_index_0205>": 32220,
215
+ "<patch_index_0206>": 32221,
216
+ "<patch_index_0207>": 32222,
217
+ "<patch_index_0208>": 32223,
218
+ "<patch_index_0209>": 32224,
219
+ "<patch_index_0210>": 32225,
220
+ "<patch_index_0211>": 32226,
221
+ "<patch_index_0212>": 32227,
222
+ "<patch_index_0213>": 32228,
223
+ "<patch_index_0214>": 32229,
224
+ "<patch_index_0215>": 32230,
225
+ "<patch_index_0216>": 32231,
226
+ "<patch_index_0217>": 32232,
227
+ "<patch_index_0218>": 32233,
228
+ "<patch_index_0219>": 32234,
229
+ "<patch_index_0220>": 32235,
230
+ "<patch_index_0221>": 32236,
231
+ "<patch_index_0222>": 32237,
232
+ "<patch_index_0223>": 32238,
233
+ "<patch_index_0224>": 32239,
234
+ "<patch_index_0225>": 32240,
235
+ "<patch_index_0226>": 32241,
236
+ "<patch_index_0227>": 32242,
237
+ "<patch_index_0228>": 32243,
238
+ "<patch_index_0229>": 32244,
239
+ "<patch_index_0230>": 32245,
240
+ "<patch_index_0231>": 32246,
241
+ "<patch_index_0232>": 32247,
242
+ "<patch_index_0233>": 32248,
243
+ "<patch_index_0234>": 32249,
244
+ "<patch_index_0235>": 32250,
245
+ "<patch_index_0236>": 32251,
246
+ "<patch_index_0237>": 32252,
247
+ "<patch_index_0238>": 32253,
248
+ "<patch_index_0239>": 32254,
249
+ "<patch_index_0240>": 32255,
250
+ "<patch_index_0241>": 32256,
251
+ "<patch_index_0242>": 32257,
252
+ "<patch_index_0243>": 32258,
253
+ "<patch_index_0244>": 32259,
254
+ "<patch_index_0245>": 32260,
255
+ "<patch_index_0246>": 32261,
256
+ "<patch_index_0247>": 32262,
257
+ "<patch_index_0248>": 32263,
258
+ "<patch_index_0249>": 32264,
259
+ "<patch_index_0250>": 32265,
260
+ "<patch_index_0251>": 32266,
261
+ "<patch_index_0252>": 32267,
262
+ "<patch_index_0253>": 32268,
263
+ "<patch_index_0254>": 32269,
264
+ "<patch_index_0255>": 32270,
265
+ "<patch_index_0256>": 32271,
266
+ "<phrase>": 32009,
267
+ "[/IMG]": 32002,
268
+ "[/gIMG]": 32005,
269
+ "[EOC]": 32006,
270
+ "[IMG]": 32001,
271
+ "[PAD]": 32000,
272
+ "[VIDEO]": 32007,
273
+ "[gIMG]": 32004
274
+ }
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,285 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "[IMG]",
4
+ "[/IMG]",
5
+ "<image>",
6
+ "[gIMG]",
7
+ "[/gIMG]",
8
+ "[EOC]",
9
+ "[VIDEO]",
10
+ "<grounding>",
11
+ "<phrase>",
12
+ "</phrase>",
13
+ "<object>",
14
+ "</object>",
15
+ "</delimiter_of_multi_objects/>",
16
+ "<REC>",
17
+ "<patch_index_0000>",
18
+ "<patch_index_0001>",
19
+ "<patch_index_0002>",
20
+ "<patch_index_0003>",
21
+ "<patch_index_0004>",
22
+ "<patch_index_0005>",
23
+ "<patch_index_0006>",
24
+ "<patch_index_0007>",
25
+ "<patch_index_0008>",
26
+ "<patch_index_0009>",
27
+ "<patch_index_0010>",
28
+ "<patch_index_0011>",
29
+ "<patch_index_0012>",
30
+ "<patch_index_0013>",
31
+ "<patch_index_0014>",
32
+ "<patch_index_0015>",
33
+ "<patch_index_0016>",
34
+ "<patch_index_0017>",
35
+ "<patch_index_0018>",
36
+ "<patch_index_0019>",
37
+ "<patch_index_0020>",
38
+ "<patch_index_0021>",
39
+ "<patch_index_0022>",
40
+ "<patch_index_0023>",
41
+ "<patch_index_0024>",
42
+ "<patch_index_0025>",
43
+ "<patch_index_0026>",
44
+ "<patch_index_0027>",
45
+ "<patch_index_0028>",
46
+ "<patch_index_0029>",
47
+ "<patch_index_0030>",
48
+ "<patch_index_0031>",
49
+ "<patch_index_0032>",
50
+ "<patch_index_0033>",
51
+ "<patch_index_0034>",
52
+ "<patch_index_0035>",
53
+ "<patch_index_0036>",
54
+ "<patch_index_0037>",
55
+ "<patch_index_0038>",
56
+ "<patch_index_0039>",
57
+ "<patch_index_0040>",
58
+ "<patch_index_0041>",
59
+ "<patch_index_0042>",
60
+ "<patch_index_0043>",
61
+ "<patch_index_0044>",
62
+ "<patch_index_0045>",
63
+ "<patch_index_0046>",
64
+ "<patch_index_0047>",
65
+ "<patch_index_0048>",
66
+ "<patch_index_0049>",
67
+ "<patch_index_0050>",
68
+ "<patch_index_0051>",
69
+ "<patch_index_0052>",
70
+ "<patch_index_0053>",
71
+ "<patch_index_0054>",
72
+ "<patch_index_0055>",
73
+ "<patch_index_0056>",
74
+ "<patch_index_0057>",
75
+ "<patch_index_0058>",
76
+ "<patch_index_0059>",
77
+ "<patch_index_0060>",
78
+ "<patch_index_0061>",
79
+ "<patch_index_0062>",
80
+ "<patch_index_0063>",
81
+ "<patch_index_0064>",
82
+ "<patch_index_0065>",
83
+ "<patch_index_0066>",
84
+ "<patch_index_0067>",
85
+ "<patch_index_0068>",
86
+ "<patch_index_0069>",
87
+ "<patch_index_0070>",
88
+ "<patch_index_0071>",
89
+ "<patch_index_0072>",
90
+ "<patch_index_0073>",
91
+ "<patch_index_0074>",
92
+ "<patch_index_0075>",
93
+ "<patch_index_0076>",
94
+ "<patch_index_0077>",
95
+ "<patch_index_0078>",
96
+ "<patch_index_0079>",
97
+ "<patch_index_0080>",
98
+ "<patch_index_0081>",
99
+ "<patch_index_0082>",
100
+ "<patch_index_0083>",
101
+ "<patch_index_0084>",
102
+ "<patch_index_0085>",
103
+ "<patch_index_0086>",
104
+ "<patch_index_0087>",
105
+ "<patch_index_0088>",
106
+ "<patch_index_0089>",
107
+ "<patch_index_0090>",
108
+ "<patch_index_0091>",
109
+ "<patch_index_0092>",
110
+ "<patch_index_0093>",
111
+ "<patch_index_0094>",
112
+ "<patch_index_0095>",
113
+ "<patch_index_0096>",
114
+ "<patch_index_0097>",
115
+ "<patch_index_0098>",
116
+ "<patch_index_0099>",
117
+ "<patch_index_0100>",
118
+ "<patch_index_0101>",
119
+ "<patch_index_0102>",
120
+ "<patch_index_0103>",
121
+ "<patch_index_0104>",
122
+ "<patch_index_0105>",
123
+ "<patch_index_0106>",
124
+ "<patch_index_0107>",
125
+ "<patch_index_0108>",
126
+ "<patch_index_0109>",
127
+ "<patch_index_0110>",
128
+ "<patch_index_0111>",
129
+ "<patch_index_0112>",
130
+ "<patch_index_0113>",
131
+ "<patch_index_0114>",
132
+ "<patch_index_0115>",
133
+ "<patch_index_0116>",
134
+ "<patch_index_0117>",
135
+ "<patch_index_0118>",
136
+ "<patch_index_0119>",
137
+ "<patch_index_0120>",
138
+ "<patch_index_0121>",
139
+ "<patch_index_0122>",
140
+ "<patch_index_0123>",
141
+ "<patch_index_0124>",
142
+ "<patch_index_0125>",
143
+ "<patch_index_0126>",
144
+ "<patch_index_0127>",
145
+ "<patch_index_0128>",
146
+ "<patch_index_0129>",
147
+ "<patch_index_0130>",
148
+ "<patch_index_0131>",
149
+ "<patch_index_0132>",
150
+ "<patch_index_0133>",
151
+ "<patch_index_0134>",
152
+ "<patch_index_0135>",
153
+ "<patch_index_0136>",
154
+ "<patch_index_0137>",
155
+ "<patch_index_0138>",
156
+ "<patch_index_0139>",
157
+ "<patch_index_0140>",
158
+ "<patch_index_0141>",
159
+ "<patch_index_0142>",
160
+ "<patch_index_0143>",
161
+ "<patch_index_0144>",
162
+ "<patch_index_0145>",
163
+ "<patch_index_0146>",
164
+ "<patch_index_0147>",
165
+ "<patch_index_0148>",
166
+ "<patch_index_0149>",
167
+ "<patch_index_0150>",
168
+ "<patch_index_0151>",
169
+ "<patch_index_0152>",
170
+ "<patch_index_0153>",
171
+ "<patch_index_0154>",
172
+ "<patch_index_0155>",
173
+ "<patch_index_0156>",
174
+ "<patch_index_0157>",
175
+ "<patch_index_0158>",
176
+ "<patch_index_0159>",
177
+ "<patch_index_0160>",
178
+ "<patch_index_0161>",
179
+ "<patch_index_0162>",
180
+ "<patch_index_0163>",
181
+ "<patch_index_0164>",
182
+ "<patch_index_0165>",
183
+ "<patch_index_0166>",
184
+ "<patch_index_0167>",
185
+ "<patch_index_0168>",
186
+ "<patch_index_0169>",
187
+ "<patch_index_0170>",
188
+ "<patch_index_0171>",
189
+ "<patch_index_0172>",
190
+ "<patch_index_0173>",
191
+ "<patch_index_0174>",
192
+ "<patch_index_0175>",
193
+ "<patch_index_0176>",
194
+ "<patch_index_0177>",
195
+ "<patch_index_0178>",
196
+ "<patch_index_0179>",
197
+ "<patch_index_0180>",
198
+ "<patch_index_0181>",
199
+ "<patch_index_0182>",
200
+ "<patch_index_0183>",
201
+ "<patch_index_0184>",
202
+ "<patch_index_0185>",
203
+ "<patch_index_0186>",
204
+ "<patch_index_0187>",
205
+ "<patch_index_0188>",
206
+ "<patch_index_0189>",
207
+ "<patch_index_0190>",
208
+ "<patch_index_0191>",
209
+ "<patch_index_0192>",
210
+ "<patch_index_0193>",
211
+ "<patch_index_0194>",
212
+ "<patch_index_0195>",
213
+ "<patch_index_0196>",
214
+ "<patch_index_0197>",
215
+ "<patch_index_0198>",
216
+ "<patch_index_0199>",
217
+ "<patch_index_0200>",
218
+ "<patch_index_0201>",
219
+ "<patch_index_0202>",
220
+ "<patch_index_0203>",
221
+ "<patch_index_0204>",
222
+ "<patch_index_0205>",
223
+ "<patch_index_0206>",
224
+ "<patch_index_0207>",
225
+ "<patch_index_0208>",
226
+ "<patch_index_0209>",
227
+ "<patch_index_0210>",
228
+ "<patch_index_0211>",
229
+ "<patch_index_0212>",
230
+ "<patch_index_0213>",
231
+ "<patch_index_0214>",
232
+ "<patch_index_0215>",
233
+ "<patch_index_0216>",
234
+ "<patch_index_0217>",
235
+ "<patch_index_0218>",
236
+ "<patch_index_0219>",
237
+ "<patch_index_0220>",
238
+ "<patch_index_0221>",
239
+ "<patch_index_0222>",
240
+ "<patch_index_0223>",
241
+ "<patch_index_0224>",
242
+ "<patch_index_0225>",
243
+ "<patch_index_0226>",
244
+ "<patch_index_0227>",
245
+ "<patch_index_0228>",
246
+ "<patch_index_0229>",
247
+ "<patch_index_0230>",
248
+ "<patch_index_0231>",
249
+ "<patch_index_0232>",
250
+ "<patch_index_0233>",
251
+ "<patch_index_0234>",
252
+ "<patch_index_0235>",
253
+ "<patch_index_0236>",
254
+ "<patch_index_0237>",
255
+ "<patch_index_0238>",
256
+ "<patch_index_0239>",
257
+ "<patch_index_0240>",
258
+ "<patch_index_0241>",
259
+ "<patch_index_0242>",
260
+ "<patch_index_0243>",
261
+ "<patch_index_0244>",
262
+ "<patch_index_0245>",
263
+ "<patch_index_0246>",
264
+ "<patch_index_0247>",
265
+ "<patch_index_0248>",
266
+ "<patch_index_0249>",
267
+ "<patch_index_0250>",
268
+ "<patch_index_0251>",
269
+ "<patch_index_0252>",
270
+ "<patch_index_0253>",
271
+ "<patch_index_0254>",
272
+ "<patch_index_0255>",
273
+ "<patch_index_0256>"
274
+ ],
275
+ "bos_token": "<s>",
276
+ "eos_token": "</s>",
277
+ "pad_token": "[PAD]",
278
+ "unk_token": {
279
+ "content": "<unk>",
280
+ "lstrip": false,
281
+ "normalized": true,
282
+ "rstrip": false,
283
+ "single_word": false
284
+ }
285
+ }
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,34 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": true,
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "pad_token": null,
24
+ "sp_model_kwargs": {},
25
+ "tokenizer_class": "LlamaTokenizer",
26
+ "unk_token": {
27
+ "__type": "AddedToken",
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": true,
31
+ "rstrip": false,
32
+ "single_word": false
33
+ }
34
+ }
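tokenizer_config.json selects LlamaTokenizer with add_bos_token=true and add_eos_token=false, so encoded prompts get a leading <s> and no trailing </s>; pad_token is null here and is supplied as [PAD] by special_tokens_map.json. A small sketch of that behavior, using the same placeholder path as above:

from transformers import LlamaTokenizer

repo = "path/to/Emu2-VisualGeneration"  # placeholder: local clone or Hub repo id
tok = LlamaTokenizer.from_pretrained(repo, subfolder="tokenizer")

ids = tok("a photo of a cat").input_ids
assert ids[0] == tok.bos_token_id       # add_bos_token: true
assert ids[-1] != tok.eos_token_id      # add_eos_token: false
print(tok.pad_token)                    # expected "[PAD]", taken from special_tokens_map.json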
unet/config.json ADDED
@@ -0,0 +1,72 @@
+ {
+ "_class_name": "UNet2DConditionModel",
+ "_diffusers_version": "0.21.2",
+ "_name_or_path": "/share/project/quansun/release_hf/Emu2-VisualGeneration/unet",
+ "act_fn": "silu",
+ "addition_embed_type": "text_time",
+ "addition_embed_type_num_heads": 64,
+ "addition_time_embed_dim": 256,
+ "attention_head_dim": [
+ 5,
+ 10,
+ 20
+ ],
+ "attention_type": "default",
+ "block_out_channels": [
+ 320,
+ 640,
+ 1280
+ ],
+ "center_input_sample": false,
+ "class_embed_type": null,
+ "class_embeddings_concat": false,
+ "conv_in_kernel": 3,
+ "conv_out_kernel": 3,
+ "cross_attention_dim": 1792,
+ "cross_attention_norm": null,
+ "down_block_types": [
+ "DownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D"
+ ],
+ "downsample_padding": 1,
+ "dropout": 0.0,
+ "dual_cross_attention": false,
+ "encoder_hid_dim": null,
+ "encoder_hid_dim_type": null,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 4,
+ "layers_per_block": 2,
+ "mid_block_only_cross_attention": null,
+ "mid_block_scale_factor": 1,
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_attention_heads": null,
+ "num_class_embeds": null,
+ "only_cross_attention": false,
+ "out_channels": 4,
+ "projection_class_embeddings_input_dim": 3328,
+ "resnet_out_scale_factor": 1.0,
+ "resnet_skip_time_act": false,
+ "resnet_time_scale_shift": "default",
+ "sample_size": 128,
+ "time_cond_proj_dim": null,
+ "time_embedding_act_fn": null,
+ "time_embedding_dim": null,
+ "time_embedding_type": "positional",
+ "timestep_post_act": null,
+ "transformer_layers_per_block": [
+ 1,
+ 2,
+ 10
+ ],
+ "up_block_types": [
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "UpBlock2D"
+ ],
+ "upcast_attention": null,
+ "use_linear_projection": true
+ }
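Two settings distinguish this from a stock SDXL-style UNet: cross_attention_dim is 1792, matching the multimodal encoder's 1792-dim width declared elsewhere in this commit rather than a CLIP text encoder, and projection_class_embeddings_input_dim is 3328, consistent with a 1792-dim pooled embedding plus 6 x 256 added time-embedding channels. A hedged loading sketch; only bf16 weights ship with the repo, hence the variant:

import torch
from diffusers import UNet2DConditionModel

repo = "path/to/Emu2-VisualGeneration"  # placeholder: local clone or Hub repo id

# diffusion_pytorch_model.bf16.safetensors is the "bf16" weight variant
unet = UNet2DConditionModel.from_pretrained(
    repo, subfolder="unet", torch_dtype=torch.bfloat16, variant="bf16"
)

# cross-attention expects 1792-dim conditioning from the multimodal encoder
assert unet.config.cross_attention_dim == 1792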
unet/diffusion_pytorch_model.bf16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:67215fe9e8e24202651fce2ff72203d21bdb7986a88ec062f72cc94f6040a314
+ size 5051265352
vae/config.json ADDED
@@ -0,0 +1,32 @@
+ {
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.21.2",
+ "_name_or_path": "/share/project/quansun/release_hf/Emu2-VisualGeneration/vae",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "force_upcast": true,
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 1024,
+ "scaling_factor": 0.13025,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ]
+ }
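The autoencoder is a 4-latent-channel KL-VAE: four encoder blocks give 8x spatial downsampling, so the sample_size of 1024 maps to 128x128 latents (the UNet's sample_size), and scaling_factor 0.13025 is the multiplier applied to latents before diffusion. A config-only sketch of those round numbers, same placeholder path as above:

from diffusers import AutoencoderKL

repo = "path/to/Emu2-VisualGeneration"  # placeholder: local clone or Hub repo id
cfg = AutoencoderKL.load_config(repo, subfolder="vae")

factor = 2 ** (len(cfg["block_out_channels"]) - 1)  # 3 downsampling steps across 4 blocks -> 8
print(cfg["sample_size"] // factor)                 # 1024 // 8 = 128, matching the UNet
print(cfg["scaling_factor"])                        # 0.13025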
vae/diffusion_pytorch_model.bf16.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2741af7e84fe3b0a7aee02f89fa34c0858ed55f5782aab5931b94938983652da
+ size 167335590
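Together with model_index.json earlier in this commit, these components (feature_extractor, multimodal_encoder, safety_checker, scheduler, tokenizer, unet, vae) are meant to be assembled into one EmuVisualGenerationPipeline. A heavily hedged end-to-end sketch: the trust_remote_code flag and the call signature below are assumptions, since the custom pipeline and modeling code are not part of this diff.

import torch
from diffusers import DiffusionPipeline

repo = "path/to/Emu2-VisualGeneration"  # placeholder: local clone or Hub repo id

pipe = DiffusionPipeline.from_pretrained(
    repo,
    torch_dtype=torch.bfloat16,
    variant="bf16",           # matches the *.bf16.safetensors files above
    trust_remote_code=True,   # assumption: needed for the custom Emu pipeline/model classes
)
pipe = pipe.to("cuda")

# The prompt interface and return type are defined by EmuVisualGenerationPipeline itself.
result = pipe("impressionist painting of an astronaut in a jungle")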