Markus28 commited on
Commit
ad76444
1 Parent(s): e55e319

feat: for converting v2, added lines to save model weights and print config

Browse files
Files changed (1) hide show
  1. convert_v2_weights.py +8 -1
convert_v2_weights.py CHANGED
@@ -131,6 +131,12 @@ new_state_dict = remap_state_dict(state_dict, config)
131
  flash_model = BertModel(config)
132
  flash_model.load_state_dict(new_state_dict)
133
 
 
 
 
 
 
 
134
  tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
135
  inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
136
  v2_model.eval()
@@ -141,4 +147,5 @@ output_v2 = v2_model(**inp)
141
  output_flash = flash_model(**inp)
142
  x = output_v2.last_hidden_state
143
  y = output_flash.last_hidden_state
144
- print(torch.abs(x - y))
 
 
131
  flash_model = BertModel(config)
132
  flash_model.load_state_dict(new_state_dict)
133
 
134
+
135
+ torch.save(new_state_dict, 'converted_weights.bin')
136
+ print(config.to_json_string())
137
+
138
+
139
+ """
140
  tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-embeddings-v2-base-en')
141
  inp = tokenizer.batch_encode_plus(['Hello world', 'How is the weather today?', 'It is raining a lot in Berlin'], return_tensors='pt', padding=True).to('cuda')
142
  v2_model.eval()
 
147
  output_flash = flash_model(**inp)
148
  x = output_v2.last_hidden_state
149
  y = output_flash.last_hidden_state
150
+ print(torch.abs(x - y))
151
+ """