Markus28 committed
Commit 9587227
1 Parent(s): 0211324

fixed GLU implementation, added conversion of layer norms

Files changed (2)
  1. convert_v2_weights.py +19 -1
  2. mlp.py +3 -2
convert_v2_weights.py CHANGED
@@ -1,6 +1,6 @@
 import re
 from collections import OrderedDict
-from transformers import AutoModel
+from transformers import AutoModel, AutoTokenizer
 from .configuration_bert import JinaBertConfig
 import torch
 from .modeling_bert import BertModel
@@ -115,6 +115,12 @@ def remap_state_dict(state_dict, config: JinaBertConfig):
         decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
     )
 
+    # LayerNorm
+    def key_mapping_layernorm(key):
+        return re.sub(r'^encoder.layers.(\d+).mlp.layernorm.(weight|bias)', r"encoder.layers.\1.norm2.\2", key)
+
+    state_dict = OrderedDict((key_mapping_layernorm(k), v) for k, v in state_dict.items())
+
     return state_dict
 
 
@@ -124,3 +130,15 @@ state_dict = v2_model.state_dict()
 new_state_dict = remap_state_dict(state_dict, config)
 flash_model = BertModel(config)
 flash_model.load_state_dict(new_state_dict)
+
+tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
+inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
+v2_model.eval()
+flash_model.eval()
+v2_model = v2_model.to('cuda', torch.float16)
+flash_model = flash_model.to('cuda', torch.float16)
+output_v2 = v2_model(**inp)
+output_flash = flash_model(**inp)
+x = output_v2.last_hidden_state
+y = output_flash.last_hidden_state
+print(torch.abs(x - y))
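
For reference, the new key_mapping_layernorm helper only renames the v2 checkpoint's MLP-internal LayerNorm parameters (encoder.layers.<i>.mlp.layernorm.*) to the norm2.* names used by the flash BertModel; all other keys pass through unchanged. A minimal, self-contained sketch of the renaming (the tensor shapes and the embedding key below are placeholders for illustration, not checkpoint values):

import re
import torch
from collections import OrderedDict

def key_mapping_layernorm(key):
    # rename the v2 MLP-internal LayerNorm to the post-MLP norm ("norm2")
    return re.sub(
        r'^encoder.layers.(\d+).mlp.layernorm.(weight|bias)',
        r"encoder.layers.\1.norm2.\2",
        key,
    )

# toy state dict: shapes and the non-matching key are made up for this example
state_dict = OrderedDict([
    ("encoder.layers.0.mlp.layernorm.weight", torch.ones(768)),
    ("encoder.layers.0.mlp.layernorm.bias", torch.zeros(768)),
    ("embeddings.word_embeddings.weight", torch.randn(16, 768)),
])
remapped = OrderedDict((key_mapping_layernorm(k), v) for k, v in state_dict.items())
print(list(remapped.keys()))
# ['encoder.layers.0.norm2.weight', 'encoder.layers.0.norm2.bias',
#  'embeddings.word_embeddings.weight']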
mlp.py CHANGED
@@ -37,6 +37,7 @@ class GLUMLP(nn.Module):
         hidden_dropout_prob=0.1
     ):
         super().__init__()
+        self.hidden_features = hidden_features
         self.gated_layers = nn.Linear(
             in_features, hidden_features * 2, bias=False
         )
@@ -57,8 +58,8 @@ class GLUMLP(nn.Module):
         residual_connection = hidden_states
         # compute the activation
         hidden_states = self.gated_layers(hidden_states)
-        gated = hidden_states[:, :, : self.config.intermediate_size]
-        non_gated = hidden_states[:, :, self.config.intermediate_size :]
+        gated = hidden_states[:, : self.hidden_features]
+        non_gated = hidden_states[:, self.hidden_features :]
         hidden_states = self.act(gated) * non_gated
         hidden_states = self.dropout(hidden_states)
         # multiply by the second matrix
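
The GLU fix is the two-line split above: gated_layers produces 2 * hidden_features channels in a single matmul, and the gate/value halves are now sliced with the stored hidden_features along the feature dimension of a 2-D hidden-states tensor (presumably the unpadded (total_tokens, dim) layout used by the flash implementation), instead of indexing a third dimension via self.config.intermediate_size, which does not appear to be available on the module. A minimal sketch of the resulting gated MLP; the class name, GELU activation, and output projection wo are illustrative assumptions, not the repo's exact definitions:

import torch
import torch.nn as nn

class TinyGLUMLP(nn.Module):
    # hypothetical stand-in for GLUMLP, mirroring only the fixed gate/value split
    def __init__(self, in_features, hidden_features, hidden_dropout_prob=0.1):
        super().__init__()
        self.hidden_features = hidden_features
        # one projection yields both halves, stacked along the feature dimension
        self.gated_layers = nn.Linear(in_features, hidden_features * 2, bias=False)
        self.act = nn.GELU()
        self.wo = nn.Linear(hidden_features, in_features)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states):
        # assumed unpadded layout: (total_tokens, in_features), no batch dimension
        hidden_states = self.gated_layers(hidden_states)
        gated = hidden_states[:, : self.hidden_features]      # gate half
        non_gated = hidden_states[:, self.hidden_features :]  # value half
        return self.wo(self.dropout(self.act(gated) * non_gated))

x = torch.randn(5, 768)                  # 5 unpadded tokens, hidden size 768
print(TinyGLUMLP(768, 3072)(x).shape)    # torch.Size([5, 768])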