Upload bge_custom_impl.py with huggingface_hub
bge_custom_impl.py
ADDED
@@ -0,0 +1,54 @@
import json
import os
from typing import Any

import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer


class ConcatCustomPooling(nn.Module):
    """Embeds text by concatenating the L2-normalized [CLS] vectors of selected hidden layers."""

    def __init__(self, model_name_or_path="BAAI/bge-large-en-v1.5", layers=None, max_seq_len=512, **kwargs):
        super().__init__(**kwargs)
        self.layers = layers
        self.base_name = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModel.from_pretrained(model_name_or_path)
        self.model.eval()
        self.max_seq_len = max_seq_len

    def tokenize(self, inputs: list[str]):
        # Truncate to the configured maximum sequence length
        return self.tokenizer(
            inputs, padding=True, truncation=True, max_length=self.max_seq_len, return_tensors="pt"
        )

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        model_output = self.model(**features, output_hidden_states=True)
        embeddings = model_output.hidden_states
        layers_embeddings = embeddings[1:]  # Drop the first entry, which is the raw embedding layer
        number_of_layers = len(layers_embeddings)
        if self.layers is None:
            self.layers = list(range(number_of_layers))
        # L2-normalize the [CLS] token of each selected layer, then stack: (batch, num_layers, hidden)
        cls_embeddings = torch.stack(
            [
                torch.nn.functional.normalize(layer[:, 0, :], p=2, dim=1)
                for layer_idx, layer in enumerate(layers_embeddings)
                if layer_idx in self.layers
            ],
            dim=1,
        )

        batch_size, layer_num, hidden_dim = cls_embeddings.shape

        # Reshape to concatenate the layer_num and hidden_dim dimensions
        cls_embeddings_concat = cls_embeddings.view(batch_size, -1)
        return {"sentence_embedding": cls_embeddings_concat}

    def get_config_dict(self) -> dict[str, Any]:
        # Keys mirror the __init__ arguments so that load() can rebuild the module
        return {"model_name_or_path": self.base_name, "layers": self.layers, "max_seq_len": self.max_seq_len}

    def get_max_seq_length(self) -> int:
        return self.max_seq_len

    def save(self, save_dir: str, **kwargs) -> None:
        with open(os.path.join(save_dir, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=4)

    @staticmethod
    def load(load_dir: str, **kwargs) -> "ConcatCustomPooling":
        # Static, so it can be called on the class as ConcatCustomPooling.load(path)
        with open(os.path.join(load_dir, "config.json")) as fIn:
            config = json.load(fIn)
        return ConcatCustomPooling(**config)
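For reference, a minimal stand-alone usage sketch. The layer indices and input sentences are illustrative, not taken from the uploaded file, and the shape comment assumes bge-large-en-v1.5 (hidden size 1024):

# Usage sketch (assumption: module used directly, outside a larger pipeline)
pooling = ConcatCustomPooling(model_name_or_path="BAAI/bge-large-en-v1.5", layers=[22, 23])
features = pooling.tokenize(["what is deep learning?", "bge embeddings from concatenated hidden layers"])
with torch.no_grad():
    output = pooling(features)
# Each embedding is len(layers) * hidden_dim wide: 2 * 1024 = 2048 here
print(output["sentence_embedding"].shape)  # torch.Size([2, 2048])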