# modified from https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py
import os
from typing import Dict

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import AutoModel, AutoConfig
from huggingface_hub import snapshot_download


class BGEM3InferenceModel(nn.Module):
    """Inference-only wrapper around BGE-M3 that returns dense, sparse,
    and ColBERT-style multi-vector embeddings in a single forward pass."""

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
    ) -> None:
        super().__init__()

        # Download only the files needed for inference; snapshot_download
        # returns the local directory containing them.
        model_dir = snapshot_download(
            repo_id=model_name,
            allow_patterns=[
                "model.safetensors",
                "colbert_linear.pt",
                "sparse_linear.pt",
                "config.json",
            ],
        )

        self.config = AutoConfig.from_pretrained(model_dir)
        self.model = AutoModel.from_pretrained(model_dir)

        # Projection head for per-token ColBERT vectors. With the default
        # colbert_dim == -1 it keeps the backbone's hidden size.
        self.colbert_linear = nn.Linear(
            in_features=self.model.config.hidden_size,
            out_features=(
                self.model.config.hidden_size if colbert_dim == -1 else colbert_dim
            ),
        )
        # Scalar head that scores each token for the sparse (lexical) embedding.
        self.sparse_linear = nn.Linear(
            in_features=self.model.config.hidden_size, out_features=1
        )

        # The two heads ship as separate .pt state dicts alongside the
        # backbone weights.
        colbert_state_dict = torch.load(
            os.path.join(model_dir, "colbert_linear.pt"), map_location="cpu"
        )
        sparse_state_dict = torch.load(
            os.path.join(model_dir, "sparse_linear.pt"), map_location="cpu"
        )
        self.colbert_linear.load_state_dict(colbert_state_dict)
        self.sparse_linear.load_state_dict(sparse_state_dict)

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        # The [CLS] token's hidden state serves as the dense embedding.
        return last_hidden_state[:, 0]

    def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
        # ReLU keeps only positive per-token weights, yielding a sparse signal.
        with torch.no_grad():
            return torch.relu(self.sparse_linear(last_hidden_state))

    def colbert_embedding(
        self, last_hidden_state: Tensor, attention_mask: Tensor
    ) -> Tensor:
        with torch.no_grad():
            # Skip the [CLS] position, project each remaining token, then
            # zero out padding positions via the attention mask.
            colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
            colbert_vecs = colbert_vecs * attention_mask[:, 1:][:, :, None].float()
        return colbert_vecs

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

        output = {}
        dense_vecs = self.dense_embedding(last_hidden_state)
        output["dense_vecs"] = F.normalize(dense_vecs, dim=-1)
        sparse_vecs = self.sparse_embedding(last_hidden_state)
        output["sparse_vecs"] = sparse_vecs
        colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)
        output["colbert_vecs"] = F.normalize(colbert_vecs, dim=-1)
        return output
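

# --- Usage sketch (not part of the original file) ---
# A minimal example of driving the model above, assuming the matching
# "BAAI/bge-m3" tokenizer on the Hugging Face Hub; the input texts are
# placeholders.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
    model = BGEM3InferenceModel()
    model.eval()

    batch = tokenizer(
        ["An example query", "An example passage"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

    # dense_vecs:   (batch, hidden)               L2-normalized [CLS] embeddings
    # sparse_vecs:  (batch, seq_len, 1)           per-token weights after ReLU
    # colbert_vecs: (batch, seq_len - 1, dim)     normalized per-token vectors
    print(out["dense_vecs"].shape, out["sparse_vecs"].shape, out["colbert_vecs"].shape)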