Tomor0720 committed on
Commit
e82ce45
Parent: 68f6838

Upload bge_custom_impl.py with huggingface_hub

Files changed (1)
  1. bge_custom_impl.py +54 -0
bge_custom_impl.py ADDED
@@ -0,0 +1,54 @@
import json
import os
from typing import Any

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class ConcatCustomPooling(nn.Module):
    """Embeds sentences by concatenating L2-normalized [CLS] vectors from selected hidden layers."""

    def __init__(self, model_name_or_path="BAAI/bge-large-en-v1.5", layers=None, max_seq_len=512, **kwargs):
        super().__init__(**kwargs)
        self.layers = layers
        self.base_name = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.max_seq_len = max_seq_len

    def tokenize(self, inputs: list[str]):
        return self.tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=self.max_seq_len,
            return_tensors="pt",
        )

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        model_output = self.model(**features, output_hidden_states=True)
        # hidden_states[0] is the raw embedding layer; keep only the transformer layer outputs.
        layers_embeddings = model_output.hidden_states[1:]
        number_of_layers = len(layers_embeddings)
        if self.layers is None:
            self.layers = list(range(number_of_layers))
        # Take the [CLS] token of each selected layer and L2-normalize it.
        cls_embeddings = torch.stack(
            [
                torch.nn.functional.normalize(layer[:, 0, :], p=2, dim=1)
                for layer_idx, layer in enumerate(layers_embeddings)
                if layer_idx in self.layers
            ],
            dim=1,
        )

        batch_size, layer_num, hidden_dim = cls_embeddings.shape

        # Flatten the (layer_num, hidden_dim) dimensions into one concatenated embedding.
        cls_embeddings_concat = cls_embeddings.view(batch_size, -1)
        return {"sentence_embedding": cls_embeddings_concat}

    def get_config_dict(self) -> dict[str, Any]:
        # Keys must match the __init__ arguments so load() can rebuild the module from config.json.
        return {"model_name_or_path": self.base_name, "layers": self.layers, "max_seq_len": self.max_seq_len}

    def get_max_seq_length(self) -> int:
        return self.max_seq_len

    def save(self, save_dir: str, **kwargs) -> None:
        with open(os.path.join(save_dir, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=4)

    def load(self, load_dir: str, **kwargs) -> "ConcatCustomPooling":
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)

        return ConcatCustomPooling(**config)
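
For context, below is a minimal usage sketch showing how this module might be plugged into sentence-transformers as the sole pipeline module. The layer indices (20-23) and the sentence-transformers wiring are illustrative assumptions, not part of the uploaded file.

    # Minimal usage sketch (assumption: sentence-transformers is installed and
    # ConcatCustomPooling is used as the only module in the pipeline).
    from sentence_transformers import SentenceTransformer

    from bge_custom_impl import ConcatCustomPooling

    # Hypothetical layer choice: concatenate the [CLS] vectors of the last four layers.
    pooling = ConcatCustomPooling("BAAI/bge-large-en-v1.5", layers=[20, 21, 22, 23])
    model = SentenceTransformer(modules=[pooling])

    embeddings = model.encode(["what is a custom pooling module?"])
    # Each embedding is len(layers) * hidden_dim wide, e.g. 4 * 1024 = 4096 for bge-large.
    print(embeddings.shape)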