File size: 2,113 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from abc import ABC, abstractmethod
from typing import Dict, List

from scipy.sparse import csr_array  # type: ignore


class BaseSparseEmbedding(ABC):
    """Abstract contract that every sparse embedding model must satisfy.

    Subclass it and provide both methods below to plug in a custom sparse
    embedding model. A sparse vector is expressed as a mapping from an
    integer index to a float value.
    """

    @abstractmethod
    def embed_query(self, query: str) -> Dict[int, float]:
        """Return the sparse embedding of a single query string.

        Args:
            query: The query text to embed.

        Returns:
            The sparse vector of ``query`` as an ``{index: value}`` mapping.
        """

    @abstractmethod
    def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
        """Return the sparse embeddings of several documents.

        Args:
            texts: The document texts to embed.

        Returns:
            One ``{index: value}`` mapping per entry of ``texts``.
        """


class BM25SparseEmbedding(BaseSparseEmbedding):
    """Sparse vector embedding model based on BM25.

    This class wraps the BM25 embedding function shipped with the Milvus
    model package. It requires ``pymilvus[model]`` to be installed:
    `pip install pymilvus[model]`
    For more information please refer to:
    https://milvus.io/docs/embed-with-bm25.md
    """

    def __init__(self, corpus: List[str], language: str = "en"):
        """Build the analyzer and fit the BM25 embedding function.

        Args:
            corpus: Texts the BM25 embedding function is fitted on.
            language: Language passed to ``build_default_analyzer``.
        """
        # Imported here rather than at module level, so the optional
        # ``pymilvus[model]`` extra is only needed once this class is
        # actually instantiated.
        from pymilvus.model.sparse import BM25EmbeddingFunction  # type: ignore
        from pymilvus.model.sparse.bm25.tokenizers import (  # type: ignore
            build_default_analyzer,
        )

        self.analyzer = build_default_analyzer(language=language)
        self.bm25_ef = BM25EmbeddingFunction(self.analyzer, num_workers=1)
        self.bm25_ef.fit(corpus)

    def embed_query(self, text: str) -> Dict[int, float]:
        """Embed a single query text.

        Args:
            text: The query text to embed.

        Returns:
            The sparse vector of ``text`` as a ``{column index: value}`` dict.
        """
        return self._sparse_to_dict(self.bm25_ef.encode_queries([text]))

    def embed_documents(self, texts: List[str]) -> List[Dict[int, float]]:
        """Embed several document texts.

        Args:
            texts: The document texts to embed.

        Returns:
            One ``{column index: value}`` dict per entry of ``texts``; an
            empty list when ``texts`` is empty.
        """
        if not texts:
            # Nothing to encode; skip the encoder call entirely.
            return []
        sparse_arrays = self.bm25_ef.encode_documents(texts)
        return [self._sparse_to_dict(sparse_array) for sparse_array in sparse_arrays]

    def _sparse_to_dict(self, sparse_array: csr_array) -> Dict[int, float]:
        """Convert a sparse row into a ``{column index: value}`` dict.

        Accepts both a 2-D single-row array and a 1-D array: depending on the
        SciPy version, iterating a 2-D ``csr_array`` yields either shape, and
        unpacking ``nonzero()`` into two arrays fails for the 1-D case.

        Args:
            sparse_array: The sparse vector to convert.

        Returns:
            Mapping from column index to stored value, with native ``int``
            keys and ``float`` values. Explicitly stored zeros are omitted.
        """
        csr = sparse_array.tocsr()
        # ``indices`` and ``data`` are aligned entry by entry, so each value
        # stays paired with its own column even when an explicit zero is
        # stored; filtering only the index side would shift later values.
        return {
            int(col_index): float(value)
            for col_index, value in zip(csr.indices, csr.data)
            if value != 0
        }