stefan-insilico committed on
Commit 6887a13
1 Parent(s): 565518e

Model weights

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer_config.json filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
config.json ADDED
@@ -0,0 +1,73 @@
+ {
+   "architectures": [
+     "Custom_MPTForCausalLM"
+   ],
+   "attn_config": {
+     "alibi": true,
+     "alibi_bias_max": 8,
+     "attn_impl": "torch",
+     "attn_pdrop": 0,
+     "attn_type": "multihead_attention",
+     "attn_uses_sequence_id": false,
+     "clip_qkv": null,
+     "prefix_lm": false,
+     "qk_gn": false,
+     "qk_ln": false,
+     "rope": false,
+     "rope_dail_config": {
+       "pos_idx_in_fp32": true,
+       "type": "original",
+       "xpos_scale_base": 512
+     },
+     "rope_hf_config": {
+       "factor": 1.0,
+       "type": "no_scaling"
+     },
+     "rope_impl": "dail",
+     "rope_theta": 10000,
+     "sliding_window_size": -1,
+     "softmax_scale": null
+   },
+   "auto_map": {
+     "AutoConfig": "mpt-7b--configuration_mpt.MPTConfig",
+     "AutoModelForCausalLM": "mpt-7b--modeling_mpt.MPTForCausalLM"
+   },
+   "bos_token_id": 0,
+   "d_model": 360,
+   "emb_pdrop": 0,
+   "embedding_fraction": 1.0,
+   "eos_token_id": 1,
+   "expansion_ratio": 5,
+   "fc_type": "torch",
+   "ffn_config": {
+     "fc_type": "torch",
+     "ffn_type": "mptmlp"
+   },
+   "init_config": {
+     "emb_init_std": null,
+     "emb_init_uniform_lim": null,
+     "fan_mode": "fan_in",
+     "init_div_is_residual": true,
+     "init_gain": 0,
+     "init_nonlinearity": "relu",
+     "init_std": 0.02,
+     "name": "kaiming_normal_",
+     "verbose": 0
+   },
+   "init_device": "cuda",
+   "learned_pos_emb": false,
+   "logit_scale": null,
+   "max_seq_len": 600,
+   "model_type": "mpt",
+   "n_heads": 36,
+   "n_layers": 36,
+   "no_bias": true,
+   "norm_type": "low_precision_layernorm",
+   "resid_pdrop": 0,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.35.0",
+   "use_cache": false,
+   "use_pad_tok_in_ffn": true,
+   "verbose": 0,
+   "vocab_size": 63740
+ }
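
For orientation, here is a minimal loading sketch consistent with this config and with the initialization in handler.py below. The local path "." and the bfloat16 dtype are assumptions; the printed values come directly from config.json.

# Minimal loading sketch (assumes a local checkout containing config.json,
# model.safetensors, tokenizer.json and the vendored mpt_7b sources).
import torch
from transformers import PreTrainedTokenizerFast
from precious3_gpt_multi_modal import Custom_MPTForCausalLM

path = "."  # hypothetical local checkout of this repository
model = Custom_MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=f"{path}/tokenizer.json",
    unk_token="[UNK]", pad_token="[PAD]", eos_token="[EOS]", bos_token="[BOS]",
)
print(model.config.d_model, model.config.n_layers, model.config.max_seq_len)  # 360 36 600, per config.json
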
handler.py ADDED
@@ -0,0 +1,195 @@
+ from typing import Dict, List, Any
+ import os
+ import json
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from transformers import PreTrainedTokenizerFast
+ from transformers import GenerationConfig
+ import transformers
+ import pandas as pd
+ import time
+ from precious3_gpt_multi_modal import Custom_MPTForCausalLM
+
+
+ emb_gpt_genes = pd.read_pickle('./multi-modal-data/emb_gpt_genes.pickle')
+ emb_hgt_genes = pd.read_pickle('./multi-modal-data/emb_hgt_genes.pickle')
+
+
+ def create_prompt(prompt_config):
+
+     prompt = "[BOS]"
+
+     multi_modal_prefix = '<modality0><modality1><modality2><modality3>' * 3
+
+     for k, v in prompt_config.items():
+         if k == 'instruction':
+             prompt += f"<{v}>"
+         elif k == 'up':
+             prompt += f'{multi_modal_prefix}<{k}>{v}</{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
+         elif k == 'down':
+             prompt += f'{multi_modal_prefix}<{k}>{v}</{k}>' if isinstance(v, str) else f'{multi_modal_prefix}<{k}>{" ".join(v)} </{k}>'
+         else:
+             prompt += f'<{k}>{v}</{k}>' if isinstance(v, str) else f'<{k}>{" ".join(v)} </{k}>'
+     return prompt
+
+
+ def custom_generate(input_ids,
+                     acc_embs_up_kg_mean,
+                     acc_embs_down_kg_mean,
+                     acc_embs_up_txt_mean,
+                     acc_embs_down_txt_mean,
+                     device,
+                     max_new_tokens,
+                     num_return_sequences,
+                     unique_compounds,
+                     temperature=0.8,
+                     top_p=0.2, top_k=3550, n_next_tokens=50):
+     # NOTE: `model` and `tokenizer` are expected to be defined at module level before this helper is called.
+     torch.manual_seed(137)
+
+     # Set parameters
+     # temperature - Higher value for more randomness, lower for more control
+     # top_p - Probability threshold for nucleus sampling (aka top-p sampling)
+     # top_k - Ignore logits below the top-k value to reduce randomness (if non-zero)
+     # n_next_tokens - Number of top next tokens when predicting compounds
+
+     modality0_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_kg_mean), 0).to(device)
+     modality1_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_kg_mean), 0).to(device)
+     modality2_emb = torch.unsqueeze(torch.from_numpy(acc_embs_up_txt_mean), 0).to(device)
+     modality3_emb = torch.unsqueeze(torch.from_numpy(acc_embs_down_txt_mean), 0).to(device)
+
+     # Generate sequences
+     outputs = []
+     next_token_compounds = []
+
+     for _ in range(num_return_sequences):
+         start_time = time.time()
+         generated_sequence = []
+         current_token = input_ids.clone()
+
+         for _ in range(max_new_tokens):  # Maximum length of generated sequence
+             # Forward pass through the model
+             logits = model.forward(input_ids=current_token,
+                                    modality0_emb=modality0_emb,
+                                    modality0_token_id=62191,
+                                    modality1_emb=modality1_emb,
+                                    modality1_token_id=62192,
+                                    modality2_emb=modality2_emb,
+                                    modality2_token_id=62193,
+                                    modality3_emb=modality3_emb,
+                                    modality3_token_id=62194)[0]
+
+             # Apply temperature to logits
+             if temperature != 1.0:
+                 logits = logits / temperature
+
+             # Apply top-p sampling (nucleus sampling)
+             sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+             cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
+             sorted_indices_to_remove = cumulative_probs > top_p
+
+             if top_k > 0:
+                 sorted_indices_to_remove[..., top_k:] = 1
+
+             # Set the logit values of the removed indices to a very small negative value
+             inf_tensor = torch.tensor(float("-inf")).type(torch.bfloat16).to(logits.device)
+             logits = logits.where(sorted_indices_to_remove, inf_tensor)
+
+             # Record the top candidate tokens whenever the previous token is <drug>
+             if current_token[0][-1] == tokenizer.encode('<drug>')[0]:
+                 next_token_compounds.append(torch.topk(torch.softmax(logits, dim=-1)[0][len(current_token[0]) - 1, :].flatten(), 50).indices)
+
+             # Sample the next token
+             next_token = torch.multinomial(torch.softmax(logits, dim=-1)[0], num_samples=1)[len(current_token[0]) - 1, :].unsqueeze(0)
+
+             # Append the sampled token to the generated sequence
+             generated_sequence.append(next_token.item())
+
+             # Stop generation if an end token is generated
+             if next_token == tokenizer.eos_token_id:
+                 break
+
+             # Prepare input for the next iteration
+             current_token = torch.cat((current_token, next_token), dim=-1)
+         print(time.time() - start_time)
+         outputs.append(generated_sequence)
+     return outputs, next_token_compounds
+
+
+ def get_predicted_compounds(input_ids, generation_output, tokenizer, p3_compounds):
+     id_4_drug_token = list(generation_output.sequences[0][len(input_ids[0]):]).index(tokenizer.convert_tokens_to_ids(['<drug>'])[0])
+     id_4_drug_token += 1
+     print('This is token index where drug should be predicted: ', id_4_drug_token)
+
+     values, indices = torch.topk(generation_output["scores"][id_4_drug_token].view(-1), k=50)
+     indices_decoded = tokenizer.decode(indices, skip_special_tokens=True)
+
+     predicted_compound = indices_decoded.split(' ')
+     predicted_compound = [i.strip() for i in predicted_compound]
+
+     valid_compounds = sorted(set(predicted_compound) & set(p3_compounds), key=predicted_compound.index)
+     print(f"Model predicted {len(predicted_compound)} tokens. Valid compounds {len(valid_compounds)}")
+     return valid_compounds
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # Load the model and tokenizer from the repository path
+         self.model = Custom_MPTForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16).to('cuda')
+         self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=os.path.join(path, "tokenizer.json"),
+                                                  unk_token="[UNK]",
+                                                  pad_token="[PAD]",
+                                                  eos_token="[EOS]",
+                                                  bos_token="[BOS]")
+         self.model.config.pad_token_id = self.tokenizer.pad_token_id
+         self.model.config.bos_token_id = self.tokenizer.bos_token_id
+         self.model.config.eos_token_id = self.tokenizer.eos_token_id
+         unique_entities_p3 = pd.read_csv(os.path.join(path, 'all_entities_with_type.csv'))
+         self.unique_compounds_p3 = [i.strip() for i in unique_entities_p3[unique_entities_p3.type == 'compound'].entity.to_list()]
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
+         """
+         Args:
+             data (:dict:):
+                 The payload with the text prompt and generation parameters.
+         """
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", None)
+         mode = data.pop('mode', 'diff2compound')
+
+         if mode == 'diff2compound':
+             with open('./generation-configs/diff2compound.json', 'r') as f:
+                 config_data = json.load(f)
+         else:
+             with open('./generation-configs/diff2compound.json', 'r') as f:
+                 config_data = json.load(f)
+
+         # NOTE: the prompt built from the generation config is not used below; the raw `inputs` string is tokenized instead.
+         prompt = create_prompt(config_data)
+
+         inputs = self.tokenizer(inputs, return_tensors="pt")
+         input_ids = inputs["input_ids"].to('cuda')
+
+         ### Generation config https://huggingface.co/blog/how-to-generate
+         generation_config = GenerationConfig(**(parameters or {}),
+                                              pad_token_id=self.tokenizer.pad_token_id, num_return_sequences=1)
+
+         max_new_tokens = self.model.config.max_seq_len - len(input_ids[0])
+
+         torch.manual_seed(137)
+
+         with torch.no_grad():
+             generation_output = self.model.generate(
+                 input_ids=input_ids,
+                 generation_config=generation_config,
+                 return_dict_in_generate=True,
+                 output_scores=True,
+                 max_new_tokens=max_new_tokens
+             )
+         if mode == 'diff2compound':
+             predicted_compounds = get_predicted_compounds(input_ids=input_ids, generation_output=generation_output, tokenizer=self.tokenizer, p3_compounds=self.unique_compounds_p3)
+             output = {'output': predicted_compounds, "mode": mode, 'message': "Done!"}
+         else:
+             output = {'output': [None], "mode": mode, 'message': "Set mode"}
+         return output
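
For reference, a hedged sketch of how the handler above might be exercised locally. The payload keys mirror __call__; the gene symbols and sampling parameters are placeholders, and the script assumes the extra files the handler reads (generation-configs/diff2compound.json, all_entities_with_type.csv, the multi-modal-data pickles) are present in the checkout.

# Hypothetical local invocation of the EndpointHandler defined above.
from handler import EndpointHandler, create_prompt

handler = EndpointHandler(path=".")
prompt_config = {
    "instruction": "diff2compound",    # placeholder instruction tag
    "up": ["CDKN2A", "TP53"],          # placeholder up-regulated genes
    "down": ["MYC"],                   # placeholder down-regulated genes
}
payload = {
    "inputs": create_prompt(prompt_config),
    "mode": "diff2compound",
    "parameters": {"do_sample": True, "top_p": 0.2, "temperature": 0.8},
}
result = handler(payload)
print(result["mode"], result["message"], result["output"][:5])
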
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:17e1167c39df0e2ac88e4267a13f7b0d4a43b48eb124d7cef8230e6d0e98e257
+ size 178841976
precious3_gpt_multi_modal.py ADDED
@@ -0,0 +1,340 @@
+ from typing import Optional, Tuple, Union, List
+
+ import os
+ import math
+ import warnings
+ import logging
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import CrossEntropyLoss, LayerNorm
+
+ from transformers.models.mpt.modeling_mpt import MptBlock, build_mpt_alibi_tensor
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast, \
+     BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPast
+ # from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, MptForCausalLM, MptModel
+ from transformers import PreTrainedTokenizerFast
+
+ from mpt_7b.modeling_mpt import MPTModel, MPTForCausalLM, gen_attention_mask_in_length
+ from mpt_7b.configuration_mpt import MPTConfig
+ from mpt_7b.blocks import MPTBlock
+ from mpt_7b.norm import NORM_CLASS_REGISTRY
+ from mpt_7b.custom_embedding import SharedEmbedding
+ from mpt_7b.attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
+
+ log = logging.getLogger(__name__)
+
+
+ class Custom_MptModel(MPTModel):  # MptModel
+     def __init__(self, config: MPTConfig, modality0_dim=128, modality2_dim=1536):
+         config._validate_config()
+         super().__init__(config)
+         self.attn_impl = config.attn_config['attn_impl']
+         self.prefix_lm = config.attn_config['prefix_lm']
+         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
+         self.alibi = config.attn_config['alibi']
+         self.alibi_bias_max = config.attn_config['alibi_bias_max']
+         self.learned_pos_emb = config.learned_pos_emb
+         if config.init_device == 'mixed':
+             # NOTE: `dist` is not imported in this module; the 'mixed' init path is unused with this checkpoint (init_device='cuda').
+             if dist.get_local_rank() == 0:
+                 config.init_device = 'cpu'
+             else:
+                 config.init_device = 'meta'
+         if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
+             norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
+             raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
+         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
+         self.embedding_fraction = config.embedding_fraction
+         self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
+         if self.learned_pos_emb:
+             self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
+         self.emb_drop = nn.Dropout(config.emb_pdrop)
+         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
+         self.norm_f = norm_class(config.d_model, device=config.init_device)
+
+         ### Added for P3GPT - START
+         # Freeze all parameters except the projection layers
+         for param in self.wte.parameters():
+             param.requires_grad = False
+
+         for param in self.blocks.parameters():
+             param.requires_grad = False
+
+         # Add a projection layer for the custom embedding
+         # torch.set_default_dtype(torch.bfloat16)
+         self.modality0_embedding_projection = nn.ModuleList([nn.Linear(modality0_dim, config.d_model),
+                                                              # nn.BatchNorm1d(config.d_model),
+                                                              nn.ReLU(),
+                                                              nn.Linear(config.d_model, config.d_model),
+                                                              # nn.BatchNorm1d(config.d_model),
+                                                              nn.ReLU(),
+                                                              nn.Linear(config.d_model, config.d_model)])  # nn.Linear(modality0_dim, self.hidden_size)
+
+         self.modality2_embedding_projection = nn.ModuleList([nn.Linear(modality2_dim, config.d_model),
+                                                              # nn.BatchNorm1d(config.d_model),
+                                                              nn.ReLU(),
+                                                              nn.Linear(config.d_model, config.d_model),
+                                                              # nn.BatchNorm1d(config.d_model),
+                                                              nn.ReLU(),
+                                                              nn.Linear(config.d_model, config.d_model)])  # nn.Linear(modality0_dim, self.hidden_size)
+
+         ### Added for P3GPT - FINISH
+
+         self.rope = config.attn_config['rope']
+         self.rope_impl = None
+         if self.rope:
+             self.rope_impl = config.attn_config['rope_impl']
+             # NOTE: gen_rotary_embedding is expected from the vendored mpt_7b code; this branch is unused here (rope=False in config.json).
+             self.rotary_embedding = gen_rotary_embedding(rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len)
+         if config.init_device != 'meta':
+             log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
+             self.apply(self.param_init_fn)
+         self.is_causal = not self.prefix_lm
+         self._attn_bias_initialized = False
+         self.attn_bias = None
+         self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
+         if config.no_bias:
+             for module in self.modules():
+                 if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
+                     log.info(f'Removing bias from module={module!r}.')
+                     module.register_parameter('bias', None)
+                 if hasattr(module, 'use_bias'):
+                     log.info(f'Setting use_bias=False for module={module!r}.')
+                     module.use_bias = False
+         log.debug(self)
+         log.debug(f"Using {self.config.init_config['name']} initialization.")
+
+         # Initialize weights and apply final processing
+         # self.post_init()
+
+     def get_input_embeddings(self):
+         return self.wte
+
+     def set_input_embeddings(self, new_embeddings):
+         # self.wte = new_embeddings
+         self.wte.weight = new_embeddings
+
+     def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None,
+                 attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None,
+                 sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None,
+                 output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None,
+                 inputs_embeds: Optional[torch.Tensor]=None, modality0_emb: Optional[bool] = None,
+                 modality0_token_id: Optional[bool] = None, modality1_emb: Optional[bool] = None, modality1_token_id: Optional[bool] = None,
+                 modality2_emb: Optional[bool] = None, modality2_token_id: Optional[bool] = None, modality3_emb: Optional[bool] = None,
+                 modality3_token_id: Optional[bool] = None,) -> BaseModelOutputWithPast:
+
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         if attention_mask is not None:
+             attention_mask = attention_mask.bool()
+         if prefix_mask is not None:
+             prefix_mask = prefix_mask.bool()
+         if not return_dict:
+             raise NotImplementedError('return_dict False is not implemented yet for MPT')
+         if output_attentions:
+             if self.attn_impl != 'torch':
+                 raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
+         if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
+             raise NotImplementedError('MPT does not support training with left padding.')
+         if self.prefix_lm and prefix_mask is None:
+             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
+         if self.training:
+             if self.attn_uses_sequence_id and sequence_id is None:
+                 raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
+             elif self.attn_uses_sequence_id is False and sequence_id is not None:
+                 warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
+
+         ### ADDED FOR P3 - START
+
+         if modality0_emb is not None:
+             modality0_emb = torch.tensor(modality0_emb, dtype=torch.bfloat16)
+             hidden_states = self.wte.weight.detach()
+
+             for layer in self.modality0_embedding_projection:
+                 modality0_emb = layer(modality0_emb)
+             proj_modality0_emb = modality0_emb
+
+             # Replace the original embedding for the custom token with the custom embedding
+             hidden_states[modality0_token_id, :] = torch.mean(torch.squeeze(proj_modality0_emb, 1), dim=0)
+             self.set_input_embeddings(torch.nn.Parameter(hidden_states))
+
+         if modality1_emb is not None:
+             modality1_emb = torch.tensor(modality1_emb, dtype=torch.bfloat16)
+             hidden_states = self.wte.weight.detach()
+
+             for layer in self.modality0_embedding_projection:
+                 modality1_emb = layer(modality1_emb)
+             proj_modality1_emb = modality1_emb
+
+             # Replace the original embedding for the custom token with the custom embedding
+             hidden_states[modality1_token_id, :] = torch.mean(torch.squeeze(proj_modality1_emb, 1), dim=0)
+             self.set_input_embeddings(torch.nn.Parameter(hidden_states))
+
+         if modality2_emb is not None:
+             modality2_emb = torch.tensor(modality2_emb, dtype=torch.bfloat16)
+             hidden_states = self.wte.weight.detach()
+
+             for layer in self.modality2_embedding_projection:
+                 modality2_emb = layer(modality2_emb)
+             proj_modality2_emb = modality2_emb
+
+             # Replace the original embedding for the custom token with the custom embedding
+             hidden_states[modality2_token_id, :] = torch.mean(torch.squeeze(proj_modality2_emb, 1), dim=0)
+             self.set_input_embeddings(torch.nn.Parameter(hidden_states))
+
+         if modality3_emb is not None:
+             modality3_emb = torch.tensor(modality3_emb, dtype=torch.bfloat16)
+             hidden_states = self.wte.weight.detach()
+
+             for layer in self.modality2_embedding_projection:
+                 modality3_emb = layer(modality3_emb)
+             proj_modality3_emb = modality3_emb
+
+             # Replace the original embedding for the custom token with the custom embedding
+             hidden_states[modality3_token_id, :] = torch.mean(torch.squeeze(proj_modality3_emb, 1), dim=0)
+             self.set_input_embeddings(torch.nn.Parameter(hidden_states))
+
+         ### ADDED FOR P3 - END
+
+         if input_ids is not None and inputs_embeds is not None:
+             raise ValueError('You cannot specify both input_ids and inputs_embeds.')
+         elif input_ids is not None:
+             bsz = input_ids.size(0)
+             S = input_ids.size(1)
+             x = self.wte(input_ids)
+             input_device = input_ids.device
+         elif inputs_embeds is not None:
+             bsz = inputs_embeds.size(0)
+             S = inputs_embeds.size(1)
+             x = inputs_embeds
+             input_device = inputs_embeds.device
+         else:
+             raise ValueError('You must specify input_ids or inputs_embeds')
+         assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
+         rotary_emb_w_meta_info = None
+         past_position = 0
+         if past_key_values is not None:
+             if len(past_key_values) != self.config.n_layers:
+                 raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
+             past_position = past_key_values[0][0].size(1)
+             if self.attn_impl == 'torch':
+                 past_position = past_key_values[0][0].size(3)
+         if self.learned_pos_emb or self.rope:
+             if self.learned_pos_emb and S + past_position > self.config.max_seq_len:
+                 raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
+             if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
+                 pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_device).unsqueeze(0)
+                 if attention_mask is not None:
+                     pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
+                 if self.learned_pos_emb:
+                     x = x + self.wpe(pos)
+                 elif self.rope and self.rope_impl == 'hf':
+                     rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': pos, 'seq_len': S + past_position}
+             elif self.rope and self.rope_impl == 'dail':
+                 rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': past_position, 'seq_len': S + past_position}
+         if self.embedding_fraction == 1:
+             x = self.emb_drop(x)
+         else:
+             x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
+             assert isinstance(self.emb_drop, nn.Module)
+             x = self.emb_drop(x_shrunk)
+         (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
+         attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask)
+         alibi_slopes = None
+         if self.alibi and self.attn_impl == 'flash':
+             alibi_slopes = gen_slopes(n_heads=self.config.n_heads, alibi_bias_max=self.alibi_bias_max, device=x.device, return_1d=True)
+
+         presents = () if use_cache else None
+         if use_cache and past_key_values is None:
+             past_key_values = [() for _ in range(self.config.n_layers)]
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+         flash_attn_padding_info = {}
+         if self.attn_impl == 'flash':
+             # NOTE: gen_flash_attn_padding_info is expected from the vendored mpt_7b code; this branch is unused here (attn_impl='torch' in config.json).
+             flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
+         for (b_idx, block) in enumerate(self.blocks):
+             if output_hidden_states:
+                 assert all_hidden_states is not None
+                 all_hidden_states = all_hidden_states + (x,)
+             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
+             (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
+             if presents is not None:
+                 presents += (present,)
+             if output_attentions:
+                 assert all_self_attns is not None
+                 all_self_attns = all_self_attns + (attn_weights,)
+         x = self.norm_f(x)
+         if output_hidden_states:
+             assert all_hidden_states is not None
+             all_hidden_states = all_hidden_states + (x,)
+         return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns)
+
+
+ class Custom_MPTForCausalLM(MPTForCausalLM):
+
+     def __init__(self, config: MPTConfig):
+         super().__init__(config)
+         # log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
+         self.transformer: MPTModel = Custom_MptModel(config)
+         self.lm_head = None
+         if not config.tie_word_embeddings:
+             self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False, device=config.init_device)
+             self.lm_head._fsdp_wrap = True
+         for child in self.transformer.children():
+             if isinstance(child, torch.nn.ModuleList):
+                 continue
+             if isinstance(child, torch.nn.Module):
+                 child._fsdp_wrap = True
+         self.logit_scale = None
+         if config.logit_scale is not None:
+             logit_scale = config.logit_scale
+             if isinstance(logit_scale, str):
+                 if logit_scale == 'inv_sqrt_d_model':
+                     logit_scale = 1 / math.sqrt(config.d_model)
+                 else:
+                     raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
+             self.logit_scale = logit_scale
+
+     def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None,
+                 attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None,
+                 sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None,
+                 return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None,
+                 use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None,
+                 modality0_emb: Optional[bool] = None, modality0_token_id: Optional[bool] = None,
+                 modality1_emb: Optional[bool] = None, modality1_token_id: Optional[bool] = None,
+                 modality2_emb: Optional[bool] = None, modality2_token_id: Optional[bool] = None,
+                 modality3_emb: Optional[bool] = None, modality3_token_id: Optional[bool] = None) -> CausalLMOutputWithPast:
+         return_dict = return_dict if return_dict is not None else self.config.return_dict
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         outputs = self.transformer(
+             input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask,
+             sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states,
+             use_cache=use_cache, inputs_embeds=inputs_embeds,
+             modality0_emb=modality0_emb,
+             modality0_token_id=modality0_token_id,
+             modality1_emb=modality1_emb,
+             modality1_token_id=modality1_token_id,
+             modality2_emb=modality2_emb,
+             modality2_token_id=modality2_token_id,
+             modality3_emb=modality3_emb,
+             modality3_token_id=modality3_token_id
+         )
+         if self.lm_head is not None:
+             logits = self.lm_head(outputs.last_hidden_state)
+         else:
+             out = outputs.last_hidden_state
+             out = out.to(self.transformer.wte.weight.device)
+             logits = self.transformer.wte(out, True)
+         if self.logit_scale is not None:
+             if self.logit_scale == 0:
+                 warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
+             logits *= self.logit_scale
+         loss = None
+         if labels is not None:
+             _labels = torch.roll(labels, shifts=-1)
+             _labels[:, -1] = -100
+             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
+         return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
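
To make the modality-injection path above concrete, a small sketch under stated assumptions: the 128-dimensional embedding is a random stand-in for the knowledge-graph gene embeddings loaded in handler.py, token id 62191 for <modality0> and the BOS id 0 are taken from custom_generate and config.json, and a local checkout of this repository is assumed.

# Sketch: inject a (stand-in) knowledge-graph embedding into the <modality0> token row,
# then run one forward pass through Custom_MPTForCausalLM.
import torch
from precious3_gpt_multi_modal import Custom_MPTForCausalLM

model = Custom_MPTForCausalLM.from_pretrained(".", torch_dtype=torch.bfloat16)  # local checkout assumed
kg_emb = torch.randn(1, 128, dtype=torch.bfloat16)   # modality0_dim=128 (see Custom_MptModel.__init__)
input_ids = torch.tensor([[0, 62191]])                # [BOS] (id 0) followed by the <modality0> token
with torch.no_grad():
    out = model(input_ids=input_ids, modality0_emb=kg_emb, modality0_token_id=62191)
print(out.logits.shape)  # expected (1, 2, 63740), given vocab_size=63740
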
special_tokens_map.json ADDED
@@ -0,0 +1,6 @@
+ {
+   "bos_token": "[BOS]",
+   "eos_token": "[EOS]",
+   "pad_token": "[PAD]",
+   "unk_token": "[UNK]"
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ea99402688e989d7fe75a55513c21cdfea22158a76765e99a102df307ff5ea5e
+ size 12308399
tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:80f8520546b55cf3bc43997f06ffcd15aa71887b6fce7e6701bac6c0d9ff55d6
+ size 11670857