stefan-insilico committed
Commit 9cb7de4
Parent: 8af136c

Rename precious3_gpt_multi_modalX.py to precious3_gpt_multi_modal.py

precious3_gpt_multi_modalX.py → precious3_gpt_multi_modal.py RENAMED
@@ -13,12 +13,12 @@ from transformers import PreTrainedTokenizerFast
 import os
 import torch.nn.functional as F
 
-from modeling_mpt import MPTModel, MPTForCausalLM, gen_attention_mask_in_length
-from configuration_mpt import MPTConfig
-from blocks import MPTBlock
-from norm import NORM_CLASS_REGISTRY
-from custom_embedding import SharedEmbedding
-from attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
+from mpt_7b.modeling_mpt import MPTModel, MPTForCausalLM, gen_attention_mask_in_length
+from mpt_7b.configuration_mpt import MPTConfig
+from mpt_7b.blocks import MPTBlock
+from mpt_7b.norm import NORM_CLASS_REGISTRY
+from mpt_7b.custom_embedding import SharedEmbedding
+from mpt_7b.attention import ATTN_CLASS_REGISTRY, attn_bias_shape, build_attn_bias, gen_slopes
 
 import logging
 log = logging.getLogger(__name__)
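The substantive change in this hunk is the import root: the vendored MPT modules are now imported through an mpt_7b package rather than as top-level modules. A minimal sketch of the layout this implies (an assumption read off the new imports, not something the commit shows directly):

    precious3_gpt_multi_modal.py
    mpt_7b/
        __init__.py            # required for the package-qualified imports to resolve
        modeling_mpt.py
        configuration_mpt.py
        blocks.py
        norm.py
        custom_embedding.py
        attention.py

If other code still expects the old flat layout, a defensive fallback import keeps both working (illustrative only, not part of this commit):

    try:
        from mpt_7b.modeling_mpt import MPTModel, MPTForCausalLM
    except ImportError:  # pre-rename flat layout
        from modeling_mpt import MPTModel, MPTForCausalLM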
@@ -85,10 +85,10 @@ class Custom_MptModel(MPTModel): # MptModel
 
 
         self.modality2_embedding_projection = nn.ModuleList([nn.Linear(modality2_dim, config.d_model),
-                                                             # nn.BatchNorm1d(config.d_model),
+                                                             # nn.BatchNorm1d(config.d_model),
                                                              nn.ReLU(),
                                                              nn.Linear(config.d_model, config.d_model),
-                                                             # nn.BatchNorm1d(config.d_model),
+                                                             # nn.BatchNorm1d(config.d_model),
                                                              nn.ReLU(),
                                                              nn.Linear(config.d_model, config.d_model)])  # nn.Linear(modality0_dim, self.hidden_size)
 
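For context, the head this hunk touches is a five-layer Linear/ReLU MLP that projects modality-2 embeddings (dimension modality2_dim) up to the model width config.d_model, with the BatchNorm1d layers commented out on both sides of the diff. Because it is an nn.ModuleList rather than nn.Sequential, it has no forward() of its own and must be applied layer by layer. A minimal sketch (the sizes and the loop are illustrative assumptions, not code from this file):

    import torch
    import torch.nn as nn

    d_model, modality2_dim = 2048, 128   # illustrative sizes, not from the commit
    projection = nn.ModuleList([nn.Linear(modality2_dim, d_model),
                                nn.ReLU(),
                                nn.Linear(d_model, d_model),
                                nn.ReLU(),
                                nn.Linear(d_model, d_model)])

    x = torch.randn(4, modality2_dim)    # a batch of modality-2 embeddings
    for layer in projection:             # ModuleList is not callable; iterate its layers
        x = layer(x)
    # x now has shape (4, d_model)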
@@ -351,4 +351,4 @@ class Custom_MPTForCausalLM(MPTForCausalLM):
         _labels = torch.roll(labels, shifts=-1)
         _labels[:, -1] = -100
         loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
-        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
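The loss in this final hunk is a standard next-token cross-entropy: labels are rolled one position left so position t is scored against token t+1, the last position (which has no target) is masked with -100, and logits are flattened to (batch * seq_len, vocab_size). Masking with -100 works because it is F.cross_entropy's default ignore_index. A self-contained sketch with illustrative shapes:

    import torch
    import torch.nn.functional as F

    batch, seq_len, vocab = 2, 5, 11                 # illustrative sizes
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))

    _labels = torch.roll(labels, shifts=-1)          # roll the flattened tensor left by one
    _labels[:, -1] = -100                            # mask the last column, which absorbed the wrap-around
    loss = F.cross_entropy(logits.view(-1, vocab),   # (batch * seq_len, vocab)
                           _labels.view(-1))         # -100 positions are ignored

Note that torch.roll without a dims argument rolls the flattened tensor, so each row's last slot picks up the next row's first token; setting that column to -100 drops it from the loss, making the result equivalent to a per-row shift.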
 