File size: 1,362 Bytes
3975d0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# modified from https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py

import os
import torch
from torch import nn, Tensor
from transformers import AutoModel, AutoConfig
from huggingface_hub import snapshot_download
from typing import Dict


class BGEM3InferenceModel(nn.Module):
    """Inference-only wrapper exposing the dense embedding of a BGE-M3 checkpoint.

    The wrapped encoder is loaded with ``AutoModel`` and always evaluated under
    ``torch.no_grad()``, so this module is not suitable for fine-tuning.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
    ) -> None:
        """Download the checkpoint snapshot and load its config and weights.

        Args:
            model_name: Repository id passed to ``snapshot_download``.
            colbert_dim: Accepted for signature compatibility; this class does
                not read it.
        """
        super().__init__()

        # Restrict the snapshot to the two files the loaders below are given;
        # everything else in the repository is skipped.
        wanted_files = ["pytorch_model.bin", "config.json"]
        checkpoint_dir = snapshot_download(
            repo_id=model_name,
            allow_patterns=wanted_files,
        )

        self.config = AutoConfig.from_pretrained(checkpoint_dir)
        self.model = AutoModel.from_pretrained(checkpoint_dir)

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return the slice at index 0 along dimension 1 of ``last_hidden_state``.

        NOTE(review): assumes the input is (batch, seq_len, hidden), in which
        case this selects the first token's vector per sequence -- confirm.
        """
        first_position = last_hidden_state[:, 0]
        return first_position

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        """Run the encoder without gradients and return its dense vectors.

        Args:
            input_ids: Forwarded to the wrapped model as ``input_ids``.
            attention_mask: Forwarded to the wrapped model as ``attention_mask``.

        Returns:
            A dict with the single key ``"dense_vecs"`` holding the result of
            ``dense_embedding``. The vectors are returned as-is (no L2
            normalisation is applied here).
        """
        with torch.no_grad():
            encoder_out = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_dict=True,
            )
            hidden = encoder_out.last_hidden_state

        # Pooling happens outside the no_grad scope, as a separate step.
        return {"dense_vecs": self.dense_embedding(hidden)}