Commit 91696a4 (verified)
stefan-insilico committed · 1 parent: de1d205

Replaced next-token generation with top-k generation for signature generation

Files changed (1)
  1. precious3_gpt_multi_modal.py +258 -131
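
The generation-side change named in the commit message (top-k sampling of signatures instead of greedy next-token decoding) is not visible in the hunks below, which cover only the model classes. For orientation, here is a minimal sketch of top-k sampling in PyTorch; the function name and parameters are illustrative and not taken from this repository:

```python
import torch
import torch.nn.functional as F

def sample_top_k(logits: torch.Tensor, k: int = 10, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id per row from the k highest-scoring logits."""
    topk_logits, topk_indices = torch.topk(logits / temperature, k, dim=-1)
    probs = F.softmax(topk_logits, dim=-1)            # renormalize over the k candidates
    choice = torch.multinomial(probs, num_samples=1)  # draw within the top-k set
    return topk_indices.gather(-1, choice)            # map back to vocabulary ids

# Greedy next-token decoding, for comparison, would be:
# next_id = logits.argmax(dim=-1, keepdim=True)
```
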
precious3_gpt_multi_modal.py CHANGED
@@ -1,17 +1,19 @@
1
  from typing import Optional, Tuple, Union, List
2
 
3
- from transformers.models.mpt.modeling_mpt import MptBlock, build_mpt_alibi_tensor
4
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 
 
5
  import torch
6
  import torch.nn as nn
7
- from torch.nn import CrossEntropyLoss, LayerNorm
8
- from transformers.models.mpt.modeling_mpt import MptBlock, build_mpt_alibi_tensor
9
- from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast, \
10
- BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPast
11
- # from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, MptForCausalLM, MptModel
12
- from transformers import PreTrainedTokenizerFast
13
- import os
14
  import torch.nn.functional as F
15
 
16
  from mpt_7b.modeling_mpt import MPTModel, MPTForCausalLM, gen_attention_mask_in_length
17
  from mpt_7b.configuration_mpt import MPTConfig
@@ -20,85 +22,96 @@ from mpt_7b.norm import NORM_CLASS_REGISTRY
20
  from mpt_7b.custom_embedding import SharedEmbedding
21
  from mpt_7b.attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
22
 
23
- import logging
24
  log = logging.getLogger(__name__)
25
 
26
 
27
- class Custom_MPTConfig(MPTConfig):
28
- def __init__(self):
29
- super().__init__()
 
30
 
31
 
32
- class Custom_MptModel(MPTModel): # MptModel
33
- def __init__(self, config: MPTConfig, modality0_dim=128, modality2_dim=1536):
34
  config._validate_config()
35
  super().__init__(config)
 
 
36
  self.attn_impl = config.attn_config['attn_impl']
37
  self.prefix_lm = config.attn_config['prefix_lm']
38
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
39
  self.alibi = config.attn_config['alibi']
40
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
41
  self.learned_pos_emb = config.learned_pos_emb
 
 
42
  if config.init_device == 'mixed':
43
  if dist.get_local_rank() == 0:
44
  config.init_device = 'cpu'
45
  else:
46
  config.init_device = 'meta'
 
47
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
48
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
49
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
 
50
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
51
  self.embedding_fraction = config.embedding_fraction
 
 
52
  self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
 
53
  if self.learned_pos_emb:
54
- self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
 
55
  self.emb_drop = nn.Dropout(config.emb_pdrop)
 
 
56
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
57
  self.norm_f = norm_class(config.d_model, device=config.init_device)
58
-
59
-
60
- ### Added for P3GPT - START
61
  # Freeze all parameters except the projection layer
62
  for param in self.wte.parameters():
63
  param.requires_grad = False
64
 
65
  for param in self.blocks.parameters():
66
  param.requires_grad = False
67
-
68
- # Add a projection layer for the custom embedding
69
- # torch.set_default_dtype(torch.bfloat16)
70
- self.modality0_embedding_projection = nn.ModuleList([nn.Linear(modality0_dim, config.d_model),
71
- # nn.BatchNorm1d(config.d_model),
72
- nn.ReLU(),
73
- nn.Linear(config.d_model, config.d_model),
74
- # nn.BatchNorm1d(config.d_model),
75
- nn.ReLU(),
76
- nn.Linear(config.d_model, config.d_model)])# nn.Linear(modality0_dim, self.hidden_size)
77
-
78
 
79
- self.modality2_embedding_projection = nn.ModuleList([nn.Linear(modality2_dim, config.d_model),
80
- # nn.BatchNorm1d(config.d_model),
81
- nn.ReLU(),
82
- nn.Linear(config.d_model, config.d_model),
83
- # nn.BatchNorm1d(config.d_model),
84
- nn.ReLU(),
85
- nn.Linear(config.d_model, config.d_model)])# nn.Linear(modality0_dim, self.hidden_size)
86
-
87
-
88
- ### Added for P3GPT - FINISH
89
-
90
  self.rope = config.attn_config['rope']
91
  self.rope_impl = None
92
  if self.rope:
93
  self.rope_impl = config.attn_config['rope_impl']
94
- self.rotary_embedding = gen_rotary_embedding(rope_head_dim=config.d_model // config.n_heads, rope_impl=self.rope_impl, rope_theta=config.attn_config['rope_theta'], rope_dail_config=config.attn_config['rope_dail_config'], rope_hf_config=config.attn_config['rope_hf_config'], max_seq_len=self.config.max_seq_len)
95
- if config.init_device != 'meta':
96
- log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
97
- self.apply(self.param_init_fn)
98
  self.is_causal = not self.prefix_lm
99
  self._attn_bias_initialized = False
100
  self.attn_bias = None
101
- self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
102
  if config.no_bias:
103
  for module in self.modules():
104
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
@@ -107,31 +120,93 @@ class Custom_MptModel(MPTModel): # MptModel
107
  if hasattr(module, 'use_bias'):
108
  log.info(f'Setting use_bias=False for module={module!r}.')
109
  module.use_bias = False
 
110
  log.debug(self)
111
  log.debug(f"Using {self.config.init_config['name']} initialization.")
112
 
113
- # Initialize weights and apply final processing
114
- # self.post_init()
 
115
 
 
 
116
 
117
- def get_input_embeddings(self):
118
  return self.wte
119
 
120
 
121
- def set_input_embeddings(self, new_embeddings):
122
- # self.wte = new_embeddings
 
123
  self.wte.weight = new_embeddings
124
 
125
-
126
- def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None,
127
- attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None,
128
- sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None,
129
- output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None,
130
- inputs_embeds: Optional[torch.Tensor]=None, modality0_emb: Optional[bool] = None,
131
- modality0_token_id: Optional[bool] = None, modality1_emb: Optional[bool] = None, modality1_token_id: Optional[bool] = None,
132
- modality2_emb: Optional[bool] = None, modality2_token_id: Optional[bool] = None, modality3_emb: Optional[bool] = None,
133
- modality3_token_id: Optional[bool] = None,) -> BaseModelOutputWithPast:
134
-
135
  return_dict = return_dict if return_dict is not None else self.config.return_dict
136
  use_cache = use_cache if use_cache is not None else self.config.use_cache
137
  if attention_mask is not None:
@@ -152,59 +227,13 @@ class Custom_MptModel(MPTModel): # MptModel
152
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
153
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
154
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
155
-
156
- ### ADDED FOR P3 - START
157
-
158
- if modality0_emb is not None:
159
- modality0_emb = torch.tensor(modality0_emb, dtype=torch.bfloat16)
160
- hidden_states = self.wte.weight.detach()
161
 
162
- for layer in self.modality0_embedding_projection:
163
- modality0_emb = layer(modality0_emb)
164
- proj_modality0_emb = modality0_emb
 
 
165
 
166
- # Replace the original embedding for the custom token with the custom embedding
167
- hidden_states[modality0_token_id, :] = torch.mean(torch.squeeze(proj_modality0_emb, 1), dim=0)
168
- self.set_input_embeddings(torch.nn.Parameter(hidden_states))
169
-
170
- if modality1_emb is not None:
171
- modality1_emb = torch.tensor(modality1_emb, dtype=torch.bfloat16)
172
- hidden_states = self.wte.weight.detach()
173
-
174
- for layer in self.modality0_embedding_projection:
175
- modality1_emb = layer(modality1_emb)
176
- proj_modality1_emb = modality1_emb
177
-
178
- # Replace the original embedding for the custom token with the custom embedding
179
- hidden_states[modality1_token_id, :] = torch.mean(torch.squeeze(proj_modality1_emb, 1), dim=0)
180
- self.set_input_embeddings(torch.nn.Parameter(hidden_states))
181
-
182
- if modality2_emb is not None:
183
- modality2_emb = torch.tensor(modality2_emb, dtype=torch.bfloat16)
184
- hidden_states = self.wte.weight.detach()
185
-
186
- for layer in self.modality2_embedding_projection:
187
- modality2_emb = layer(modality2_emb)
188
- proj_modality2_emb = modality2_emb
189
-
190
- # Replace the original embedding for the custom token with the custom embedding
191
- hidden_states[modality2_token_id, :] = torch.mean(torch.squeeze(proj_modality2_emb, 1), dim=0)
192
- self.set_input_embeddings(torch.nn.Parameter(hidden_states))
193
-
194
- if modality3_emb is not None:
195
- modality3_emb = torch.tensor(modality3_emb, dtype=torch.bfloat16)
196
- hidden_states = self.wte.weight.detach()
197
-
198
- for layer in self.modality2_embedding_projection:
199
- modality3_emb = layer(modality3_emb)
200
- proj_modality3_emb = modality3_emb
201
-
202
- # Replace the original embedding for the custom token with the custom embedding
203
- hidden_states[modality3_token_id, :] = torch.mean(torch.squeeze(proj_modality3_emb, 1), dim=0)
204
- self.set_input_embeddings(torch.nn.Parameter(hidden_states))
205
-
206
- ### ADDED FOR P3 - END
207
-
208
  if input_ids is not None and inputs_embeds is not None:
209
  raise ValueError('You cannot specify both input_ids and inputs_embeds.')
210
  elif input_ids is not None:
@@ -219,15 +248,18 @@ class Custom_MptModel(MPTModel): # MptModel
219
  input_device = inputs_embeds.device
220
  else:
221
  raise ValueError('You must specify input_ids or inputs_embeds')
 
222
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
223
  rotary_emb_w_meta_info = None
224
  past_position = 0
 
225
  if past_key_values is not None:
226
  if len(past_key_values) != self.config.n_layers:
227
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
228
  past_position = past_key_values[0][0].size(1)
229
  if self.attn_impl == 'torch':
230
  past_position = past_key_values[0][0].size(3)
 
231
  if self.learned_pos_emb or self.rope:
232
  if self.learned_pos_emb and S + past_position > self.config.max_seq_len:
233
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
@@ -241,59 +273,102 @@ class Custom_MptModel(MPTModel): # MptModel
241
  rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': pos, 'seq_len': S + past_position}
242
  elif self.rope and self.rope_impl == 'dail':
243
  rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': past_position, 'seq_len': S + past_position}
 
 
244
  if self.embedding_fraction == 1:
245
  x = self.emb_drop(x)
246
  else:
247
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
248
  assert isinstance(self.emb_drop, nn.Module)
249
  x = self.emb_drop(x_shrunk)
 
250
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
251
- attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S, attn_uses_sequence_id=self.attn_uses_sequence_id, attn_impl=self.attn_impl, attention_mask=attention_mask)
252
  alibi_slopes = None
253
  if self.alibi and self.attn_impl == 'flash':
254
  alibi_slopes = gen_slopes(n_heads=self.config.n_heads, alibi_bias_max=self.alibi_bias_max, device=x.device, return_1d=True)
255
-
256
  presents = () if use_cache else None
257
  if use_cache and past_key_values is None:
258
  past_key_values = [() for _ in range(self.config.n_layers)]
259
  all_hidden_states = () if output_hidden_states else None
260
  all_self_attns = () if output_attentions else None
 
261
  flash_attn_padding_info = {}
262
  if self.attn_impl == 'flash':
263
  flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
 
264
  for (b_idx, block) in enumerate(self.blocks):
265
  if output_hidden_states:
266
  assert all_hidden_states is not None
267
  all_hidden_states = all_hidden_states + (x,)
268
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
269
  (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
 
270
  if presents is not None:
271
  presents += (present,)
272
  if output_attentions:
273
  assert all_self_attns is not None
274
  all_self_attns = all_self_attns + (attn_weights,)
 
275
  x = self.norm_f(x)
 
276
  if output_hidden_states:
277
  assert all_hidden_states is not None
278
  all_hidden_states = all_hidden_states + (x,)
279
  return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns)
280
 
281
 
282
- class Custom_MPTForCausalLM(MPTForCausalLM):
283
 
284
- def __init__(self, config: MPTConfig):
285
  super().__init__(config)
286
- # log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
287
- self.transformer: MPTModel = Custom_MptModel(config)
 
288
  self.lm_head = None
 
289
  if not config.tie_word_embeddings:
290
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False, device=config.init_device)
291
  self.lm_head._fsdp_wrap = True
 
292
  for child in self.transformer.children():
293
  if isinstance(child, torch.nn.ModuleList):
294
  continue
295
  if isinstance(child, torch.nn.Module):
296
  child._fsdp_wrap = True
 
297
  self.logit_scale = None
298
  if config.logit_scale is not None:
299
  logit_scale = config.logit_scale
@@ -304,21 +379,69 @@ class Custom_MPTForCausalLM(MPTForCausalLM):
304
  raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
305
  self.logit_scale = logit_scale
306
 
307
- def forward(self, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None,
308
- attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None,
309
- sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None,
310
- return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None,
311
- use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None,
312
- modality0_emb: Optional[bool] = None, modality0_token_id: Optional[bool] = None,
313
- modality1_emb: Optional[bool] = None, modality1_token_id: Optional[bool] = None,
314
- modality2_emb: Optional[bool] = None, modality2_token_id: Optional[bool] = None,
315
- modality3_emb: Optional[bool] = None, modality3_token_id: Optional[bool] = None) -> CausalLMOutputWithPast:
316
  return_dict = return_dict if return_dict is not None else self.config.return_dict
317
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
318
  outputs = self.transformer(
319
- input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask,
320
- sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states,
321
- use_cache=use_cache, inputs_embeds=inputs_embeds,
322
  modality0_emb=modality0_emb,
323
  modality0_token_id=modality0_token_id,
324
  modality1_emb=modality1_emb,
@@ -328,19 +451,23 @@ class Custom_MPTForCausalLM(MPTForCausalLM):
328
  modality3_emb=modality3_emb,
329
  modality3_token_id=modality3_token_id
330
  )
 
331
  if self.lm_head is not None:
332
  logits = self.lm_head(outputs.last_hidden_state)
333
  else:
334
  out = outputs.last_hidden_state
335
  out = out.to(self.transformer.wte.weight.device)
336
  logits = self.transformer.wte(out, True)
 
337
  if self.logit_scale is not None:
338
  if self.logit_scale == 0:
339
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
340
  logits *= self.logit_scale
 
341
  loss = None
342
  if labels is not None:
343
  _labels = torch.roll(labels, shifts=-1)
344
  _labels[:, -1] = -100
345
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
346
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

Updated file contents after this commit (the added/right-hand side of the same hunks, renumbered from line 1):

1
  from typing import Optional, Tuple, Union, List
2
 
3
+ import logging
4
+ import math
5
+ import warnings
6
+
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
+ from transformers import PreTrainedTokenizerFast
11
+ from transformers.modeling_outputs import (
12
+ CausalLMOutputWithCrossAttentions,
13
+ CausalLMOutputWithPast,
14
+ BaseModelOutputWithPastAndCrossAttentions,
15
+ BaseModelOutputWithPast,
16
+ )
17
 
18
  from mpt_7b.modeling_mpt import MPTModel, MPTForCausalLM, gen_attention_mask_in_length
19
  from mpt_7b.configuration_mpt import MPTConfig
 
22
  from mpt_7b.custom_embedding import SharedEmbedding
23
  from mpt_7b.attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
24
 
 
25
  log = logging.getLogger(__name__)
26
 
27
 
28
+ class Custom_MptModel(MPTModel):
29
+ """
30
+ Custom MPT Model that extends the base MPTModel with additional functionalities
31
+ for handling multimodal embeddings and custom projections.
32
 
33
+ Args:
34
+ config (MPTConfig): Configuration object containing model parameters.
35
+ modality0_dim (int): Dimension for the first modality embedding.
36
+ modality2_dim (int): Dimension for the second modality embedding.
37
+ """
38
 
39
+ def __init__(self, config: MPTConfig, modality0_dim: int = 128, modality2_dim: int = 1536):
 
40
  config._validate_config()
41
  super().__init__(config)
42
+
43
+ # Initialize model parameters based on the configuration
44
  self.attn_impl = config.attn_config['attn_impl']
45
  self.prefix_lm = config.attn_config['prefix_lm']
46
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
47
  self.alibi = config.attn_config['alibi']
48
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
49
  self.learned_pos_emb = config.learned_pos_emb
50
+
51
+ # Set initialization device
52
  if config.init_device == 'mixed':
53
  if dist.get_local_rank() == 0:
54
  config.init_device = 'cpu'
55
  else:
56
  config.init_device = 'meta'
57
+
58
  if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
59
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
60
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
61
+
62
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
63
  self.embedding_fraction = config.embedding_fraction
64
+
65
+ # Initialize embeddings
66
  self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
67
+
68
  if self.learned_pos_emb:
69
+ self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
70
+
71
  self.emb_drop = nn.Dropout(config.emb_pdrop)
72
+
73
+ # Initialize model blocks
74
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
75
  self.norm_f = norm_class(config.d_model, device=config.init_device)
76
+
 
 
77
  # Freeze all parameters except the projection layer
78
  for param in self.wte.parameters():
79
  param.requires_grad = False
80
 
81
  for param in self.blocks.parameters():
82
  param.requires_grad = False
83
 
84
+ # Initialize projections for different modalities
85
+ self.modality0_embedding_projection = self._create_modal_projection(modality0_dim)
86
+ self.modality2_embedding_projection = self._create_modal_projection(modality2_dim)
87
+
88
+ # Other configurations
89
  self.rope = config.attn_config['rope']
90
  self.rope_impl = None
91
  if self.rope:
92
  self.rope_impl = config.attn_config['rope_impl']
93
+ self.rotary_embedding = gen_rotary_embedding(
94
+ rope_head_dim=config.d_model // config.n_heads,
95
+ rope_impl=self.rope_impl,
96
+ rope_theta=config.attn_config['rope_theta'],
97
+ rope_dail_config=config.attn_config['rope_dail_config'],
98
+ rope_hf_config=config.attn_config['rope_hf_config'],
99
+ max_seq_len=self.config.max_seq_len
100
+ )
101
+
102
  self.is_causal = not self.prefix_lm
103
  self._attn_bias_initialized = False
104
  self.attn_bias = None
105
+ self.attn_bias_shape = attn_bias_shape(
106
+ self.attn_impl,
107
+ config.n_heads,
108
+ config.max_seq_len,
109
+ self.alibi,
110
+ prefix_lm=self.prefix_lm,
111
+ causal=self.is_causal,
112
+ use_sequence_id=self.attn_uses_sequence_id
113
+ )
114
+
115
  if config.no_bias:
116
  for module in self.modules():
117
  if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
 
120
  if hasattr(module, 'use_bias'):
121
  log.info(f'Setting use_bias=False for module={module!r}.')
122
  module.use_bias = False
123
+
124
  log.debug(self)
125
  log.debug(f"Using {self.config.init_config['name']} initialization.")
126
 
127
+ def _create_modal_projection(self, modality_dim: int) -> nn.ModuleList:
128
+ """
129
+ Create a projection layer for a given modality.
130
 
131
+ Args:
132
+ modality_dim (int): Dimension of the modality embedding.
133
 
134
+ Returns:
135
+ nn.ModuleList: A module list containing layers for modal projection.
136
+ """
137
+ return nn.ModuleList([
138
+ nn.Linear(modality_dim, self.config.d_model),
139
+ nn.ReLU(),
140
+ nn.Linear(self.config.d_model, self.config.d_model),
141
+ nn.ReLU(),
142
+ nn.Linear(self.config.d_model, self.config.d_model)
143
+ ])
144
+
145
+ def get_input_embeddings(self) -> nn.Embedding:
146
+ """
147
+ Get the input word embeddings.
148
+
149
+ Returns:
150
+ nn.Embedding: The word token embeddings.
151
+ """
152
  return self.wte
153
 
154
+ def set_input_embeddings(self, new_embeddings: nn.Parameter):
155
+ """
156
+ Set the input word embeddings with new embeddings.
157
 
158
+ Args:
159
+ new_embeddings (nn.Parameter): The new word embeddings to set.
160
+ """
161
  self.wte.weight = new_embeddings
162
 
163
+ def forward(
164
+ self,
165
+ input_ids: Optional[torch.LongTensor] = None,
166
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
167
+ attention_mask: Optional[torch.ByteTensor] = None,
168
+ prefix_mask: Optional[torch.ByteTensor] = None,
169
+ sequence_id: Optional[torch.LongTensor] = None,
170
+ return_dict: Optional[bool] = None,
171
+ output_attentions: Optional[bool] = None,
172
+ output_hidden_states: Optional[bool] = None,
173
+ use_cache: Optional[bool] = None,
174
+ inputs_embeds: Optional[torch.Tensor] = None,
175
+ modality0_emb: Optional[bool] = None,
176
+ modality0_token_id: Optional[bool] = None,
177
+ modality1_emb: Optional[bool] = None,
178
+ modality1_token_id: Optional[bool] = None,
179
+ modality2_emb: Optional[bool] = None,
180
+ modality2_token_id: Optional[bool] = None,
181
+ modality3_emb: Optional[bool] = None,
182
+ modality3_token_id: Optional[bool] = None
183
+ ) -> BaseModelOutputWithPast:
184
+ """
185
+ Forward pass for the model, processing input through the network.
186
+
187
+ Args:
188
+ input_ids (Optional[torch.LongTensor]): Input tensor representing token IDs.
189
+ past_key_values (Optional[List[Tuple[torch.FloatTensor]]]): Past key values for cache.
190
+ attention_mask (Optional[torch.ByteTensor]): Attention mask to avoid attention to padding tokens.
191
+ prefix_mask (Optional[torch.ByteTensor]): Mask for the prefix input.
192
+ sequence_id (Optional[torch.LongTensor]): Sequence ID for token sequences.
193
+ return_dict (Optional[bool]): Whether to return a dict or a tuple.
194
+ output_attentions (Optional[bool]): Whether to output attention weights.
195
+ output_hidden_states (Optional[bool]): Whether to output hidden states.
196
+ use_cache (Optional[bool]): Whether to cache past key values.
197
+ inputs_embeds (Optional[torch.Tensor]): Input tensor representing embeddings.
198
+ modality0_emb (Optional[bool]): Modality 0 (KG UP genes) embedding.
199
+ modality0_token_id (Optional[bool]): Token ID for modality 0.
200
+ modality1_emb (Optional[bool]): Modality 1 (KG DOWN genes) embedding.
201
+ modality1_token_id (Optional[bool]): Token ID for modality 1.
202
+ modality2_emb (Optional[bool]): Modality 2 (TEXT UP genes) embedding.
203
+ modality2_token_id (Optional[bool]): Token ID for modality 2.
204
+ modality3_emb (Optional[bool]): Modality 3 (TEXT DOWN genes) embedding.
205
+ modality3_token_id (Optional[bool]): Token ID for modality 3.
206
+
207
+ Returns:
208
+ BaseModelOutputWithPast: Model output containing last hidden state and optional details.
209
+ """
210
  return_dict = return_dict if return_dict is not None else self.config.return_dict
211
  use_cache = use_cache if use_cache is not None else self.config.use_cache
212
  if attention_mask is not None:
 
227
  raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
228
  elif self.attn_uses_sequence_id is False and sequence_id is not None:
229
  warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
230
 
231
+ # Process modality embeddings for each modality
232
+ self._process_modalities(modality0_emb, modality0_token_id, self.modality0_embedding_projection)
233
+ self._process_modalities(modality1_emb, modality1_token_id, self.modality0_embedding_projection)
234
+ self._process_modalities(modality2_emb, modality2_token_id, self.modality2_embedding_projection)
235
+ self._process_modalities(modality3_emb, modality3_token_id, self.modality2_embedding_projection)
236
 
237
  if input_ids is not None and inputs_embeds is not None:
238
  raise ValueError('You cannot specify both input_ids and inputs_embeds.')
239
  elif input_ids is not None:
 
248
  input_device = inputs_embeds.device
249
  else:
250
  raise ValueError('You must specify input_ids or inputs_embeds')
251
+
252
  assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
253
  rotary_emb_w_meta_info = None
254
  past_position = 0
255
+
256
  if past_key_values is not None:
257
  if len(past_key_values) != self.config.n_layers:
258
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
259
  past_position = past_key_values[0][0].size(1)
260
  if self.attn_impl == 'torch':
261
  past_position = past_key_values[0][0].size(3)
262
+
263
  if self.learned_pos_emb or self.rope:
264
  if self.learned_pos_emb and S + past_position > self.config.max_seq_len:
265
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
 
273
  rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': pos, 'seq_len': S + past_position}
274
  elif self.rope and self.rope_impl == 'dail':
275
  rotary_emb_w_meta_info = {'impl': self.rope_impl, 'rotary_emb': self.rotary_embedding, 'offset_info': past_position, 'seq_len': S + past_position}
276
+
277
+ # Handle embedding fraction
278
  if self.embedding_fraction == 1:
279
  x = self.emb_drop(x)
280
  else:
281
  x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
282
  assert isinstance(self.emb_drop, nn.Module)
283
  x = self.emb_drop(x_shrunk)
284
+
285
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
286
+ attention_mask_in_length = gen_attention_mask_in_length(sequence_id=sequence_id, S=S,
287
+ attn_uses_sequence_id=self.attn_uses_sequence_id,
288
+ attn_impl=self.attn_impl,
289
+ attention_mask=attention_mask)
290
  alibi_slopes = None
291
  if self.alibi and self.attn_impl == 'flash':
292
  alibi_slopes = gen_slopes(n_heads=self.config.n_heads, alibi_bias_max=self.alibi_bias_max, device=x.device, return_1d=True)
293
+
294
  presents = () if use_cache else None
295
  if use_cache and past_key_values is None:
296
  past_key_values = [() for _ in range(self.config.n_layers)]
297
  all_hidden_states = () if output_hidden_states else None
298
  all_self_attns = () if output_attentions else None
299
+
300
  flash_attn_padding_info = {}
301
  if self.attn_impl == 'flash':
302
  flash_attn_padding_info = gen_flash_attn_padding_info(bsz, S, past_position, x.device, attention_mask_in_length, attention_mask)
303
+
304
  for (b_idx, block) in enumerate(self.blocks):
305
  if output_hidden_states:
306
  assert all_hidden_states is not None
307
  all_hidden_states = all_hidden_states + (x,)
308
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
309
  (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, rotary_emb_w_meta_info=rotary_emb_w_meta_info, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions), alibi_slopes=alibi_slopes, flash_attn_padding_info=flash_attn_padding_info)
310
+
311
  if presents is not None:
312
  presents += (present,)
313
  if output_attentions:
314
  assert all_self_attns is not None
315
  all_self_attns = all_self_attns + (attn_weights,)
316
+
317
  x = self.norm_f(x)
318
+
319
  if output_hidden_states:
320
  assert all_hidden_states is not None
321
  all_hidden_states = all_hidden_states + (x,)
322
  return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns)
323
 
324
+ def _process_modalities(self, modality_emb: Optional[bool], token_id: Optional[bool], projection: nn.ModuleList):
325
+ """
326
+ Process the modality embedding if provided, updating the input embeddings.
327
+
328
+ Args:
329
+ modality_emb (Optional[bool]): The modality embedding to process.
330
+ token_id (Optional[bool]): The token ID for the modality.
331
+ projection (nn.ModuleList): The projection layers for the modality.
332
+ """
333
+ if modality_emb is not None:
334
+ modality_emb = torch.tensor(modality_emb, dtype=torch.bfloat16)
335
+ hidden_states = self.wte.weight.detach()
336
 
337
+ for layer in projection:
338
+ modality_emb = layer(modality_emb)
339
+
340
+ proj_modality_emb = modality_emb
341
+ hidden_states[token_id, :] = torch.mean(torch.squeeze(proj_modality_emb, 1), dim=0)
342
+ self.set_input_embeddings(torch.nn.Parameter(hidden_states))
343
 
344
+
345
+ class Precious3MPTForCausalLM(MPTForCausalLM):
346
+ """
347
+ Precious3 MPT For Causal Language Modeling that utilizes the Custom_MptModel.
348
+
349
+ Args:
350
+ config (MPTConfig): Configuration object for the model.
351
+ modality0_dim (int): Dimension for the first modality embedding.
352
+ modality2_dim (int): Dimension for the second modality embedding.
353
+ """
354
+
355
+ def __init__(self, config: MPTConfig, modality0_dim: int = 128, modality2_dim: int = 1536):
356
  super().__init__(config)
357
+
358
+ # Pass the modalities dimensions to Custom_MptModel
359
+ self.transformer: MPTModel = Custom_MptModel(config, modality0_dim=modality0_dim, modality2_dim=modality2_dim)
360
  self.lm_head = None
361
+
362
  if not config.tie_word_embeddings:
363
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False, device=config.init_device)
364
  self.lm_head._fsdp_wrap = True
365
+
366
  for child in self.transformer.children():
367
  if isinstance(child, torch.nn.ModuleList):
368
  continue
369
  if isinstance(child, torch.nn.Module):
370
  child._fsdp_wrap = True
371
+
372
  self.logit_scale = None
373
  if config.logit_scale is not None:
374
  logit_scale = config.logit_scale
 
379
  raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
380
  self.logit_scale = logit_scale
381
 
382
+ def forward(
383
+ self,
384
+ input_ids: Optional[torch.LongTensor] = None,
385
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
386
+ attention_mask: Optional[torch.ByteTensor] = None,
387
+ prefix_mask: Optional[torch.ByteTensor] = None,
388
+ sequence_id: Optional[torch.LongTensor] = None,
389
+ labels: Optional[torch.LongTensor] = None,
390
+ return_dict: Optional[bool] = None,
391
+ output_attentions: Optional[bool] = None,
392
+ output_hidden_states: Optional[bool] = None,
393
+ use_cache: Optional[bool] = None,
394
+ inputs_embeds: Optional[torch.FloatTensor] = None,
395
+ modality0_emb: Optional[bool] = None,
396
+ modality0_token_id: Optional[bool] = None,
397
+ modality1_emb: Optional[bool] = None,
398
+ modality1_token_id: Optional[bool] = None,
399
+ modality2_emb: Optional[bool] = None,
400
+ modality2_token_id: Optional[bool] = None,
401
+ modality3_emb: Optional[bool] = None,
402
+ modality3_token_id: Optional[bool] = None
403
+ ) -> CausalLMOutputWithPast:
404
+ """
405
+ Forward pass through the causal language model.
406
+
407
+ Args:
408
+ input_ids (Optional[torch.LongTensor]): Input tensor for token IDs.
409
+ past_key_values (Optional[List[Tuple[torch.FloatTensor]]]): Past key values for cached states.
410
+ attention_mask (Optional[torch.ByteTensor]): Attention mask to prevent attention to padding tokens.
411
+ prefix_mask (Optional[torch.ByteTensor]): Mask for prefix inputs.
412
+ sequence_id (Optional[torch.LongTensor]): Sequence ID tensor.
413
+ labels (Optional[torch.LongTensor]): Labels for the loss computation, if applicable.
414
+ return_dict (Optional[bool]): Whether to return outputs as a dict or tuple.
415
+ output_attentions (Optional[bool]): Whether to return attention weights.
416
+ output_hidden_states (Optional[bool]): Whether to return hidden states.
417
+ use_cache (Optional[bool]): Whether to cache past key values.
418
+ inputs_embeds (Optional[torch.FloatTensor]): Input tensor for embeddings.
419
+ modality0_emb (Optional[bool]): Input for modality 0.
420
+ modality0_token_id (Optional[bool]): Token ID for modality 0.
421
+ modality1_emb (Optional[bool]): Input for modality 1.
422
+ modality1_token_id (Optional[bool]): Token ID for modality 1.
423
+ modality2_emb (Optional[bool]): Input for modality 2.
424
+ modality2_token_id (Optional[bool]): Token ID for modality 2.
425
+ modality3_emb (Optional[bool]): Input for modality 3.
426
+ modality3_token_id (Optional[bool]): Token ID for modality 3.
427
+
428
+ Returns:
429
+ CausalLMOutputWithPast: Causal language model output containing logits and past key values.
430
+ """
431
  return_dict = return_dict if return_dict is not None else self.config.return_dict
432
  use_cache = use_cache if use_cache is not None else self.config.use_cache
433
+
434
  outputs = self.transformer(
435
+ input_ids=input_ids,
436
+ past_key_values=past_key_values,
437
+ attention_mask=attention_mask,
438
+ prefix_mask=prefix_mask,
439
+ sequence_id=sequence_id,
440
+ return_dict=return_dict,
441
+ output_attentions=output_attentions,
442
+ output_hidden_states=output_hidden_states,
443
+ use_cache=use_cache,
444
+ inputs_embeds=inputs_embeds,
445
  modality0_emb=modality0_emb,
446
  modality0_token_id=modality0_token_id,
447
  modality1_emb=modality1_emb,
 
451
  modality3_emb=modality3_emb,
452
  modality3_token_id=modality3_token_id
453
  )
454
+
455
  if self.lm_head is not None:
456
  logits = self.lm_head(outputs.last_hidden_state)
457
  else:
458
  out = outputs.last_hidden_state
459
  out = out.to(self.transformer.wte.weight.device)
460
  logits = self.transformer.wte(out, True)
461
+
462
  if self.logit_scale is not None:
463
  if self.logit_scale == 0:
464
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
465
  logits *= self.logit_scale
466
+
467
  loss = None
468
  if labels is not None:
469
  _labels = torch.roll(labels, shifts=-1)
470
  _labels[:, -1] = -100
471
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
472
+
473
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
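
For context, a minimal usage sketch of the refactored causal LM, based only on the constructor and forward() signatures shown above. The config construction, token ids, and embedding values below are placeholders; in practice they would come from the Precious3 checkpoint and tokenizer.

```python
import torch

from mpt_7b.configuration_mpt import MPTConfig
from precious3_gpt_multi_modal import Precious3MPTForCausalLM

# Assumption: a real P3GPT config would be loaded from the published checkpoint,
# not instantiated with defaults as done here for illustration.
config = MPTConfig()
model = Precious3MPTForCausalLM(config, modality0_dim=128, modality2_dim=1536)
# _process_modalities casts modality embeddings to bfloat16, so the model
# weights are assumed to be bfloat16 as well.
model = model.to(torch.bfloat16).eval()

input_ids = torch.tensor([[1, 2, 3]])          # placeholder prompt token ids
attention_mask = torch.ones_like(input_ids)

# E.g. one 128-d vector per input gene for the "KG UP genes" modality; forward()
# projects these, averages them, and writes the result into the embedding row
# of the special modality token before running the transformer.
modality0_emb = torch.randn(4, 1, 128)
modality0_token_id = 42                        # placeholder id of the special modality-0 token

with torch.no_grad():
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        modality0_emb=modality0_emb,
        modality0_token_id=modality0_token_id,
    )
print(out.logits.shape)                        # (batch, seq_len, vocab_size)
```
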