File size: 3,786 Bytes
a249916
 
 
 
 
 
4f33285
a249916
 
 
 
47cf512
 
 
 
 
 
 
 
a249916
47cf512
a249916
47cf512
 
 
 
 
a249916
47cf512
 
a249916
47cf512
 
a249916
 
 
 
 
47cf512
a249916
47cf512
a249916
47cf512
 
a249916
251bfda
a249916
47cf512
 
 
 
 
 
 
 
a249916
47cf512
a249916
47cf512
 
 
 
 
a249916
47cf512
 
a249916
 
 
47cf512
a249916
47cf512
 
 
 
a249916
 
 
 
 
47cf512
 
 
a249916
 
 
 
47cf512
a249916
6deb98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a249916
251bfda
47cf512
251bfda
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import abc
from typing import List, Union

from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from .type_aliases import ENCODER_DEVICE_TYPE


class Encoder(abc.ABC):
    """Interface for objects that turn a list of sentences into embeddings."""

    @abc.abstractmethod
    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Concrete encoders must override this method; the class cannot be
        instantiated while it is still abstract.

        Args:
            prediction (List[str]): Sentences to encode.
            device (ENCODER_DEVICE_TYPE): Device specification for encoding.
                Defaults to "cpu".
            batch_size (int): Number of sentences processed per batch.
            verbose (bool): Whether to print verbose information during encoding.

        Returns:
            NDArray: Sentence embeddings with shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: Always, when this base implementation is
                invoked (for example through ``super().encode(...)``).
        """
        message = "Method 'encode' must be implemented in subclass."
        raise NotImplementedError(message)


class SBertEncoder(Encoder):
    """Encoder that delegates to a ``SentenceTransformer`` model."""

    def __init__(self, model_name: str):
        """
        Initialize SBertEncoder instance.

        Args:
            model_name (str): Name or path of the Sentence Transformer model.
        """
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model repository to run while loading -- only use model names from
        # sources you trust.
        self.model = SentenceTransformer(model_name, trust_remote_code=True)

    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        When ``device`` is a list, the sentences are encoded through a
        multi-process pool started on those devices; otherwise a single
        ``model.encode`` call is made on the given device.

        Args:
            prediction (List[str]): List of sentences to encode.
            device (ENCODER_DEVICE_TYPE): Device specification for encoding.
                A list selects multi-process encoding across its entries.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to show a progress bar during encoding.
                Only forwarded on the single-device path; it is not passed to
                the multi-process encoder.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """

        # SBert output is always Batch x Dim
        if isinstance(device, list):
            # Use multiprocess encoding for list of devices
            pool = self.model.start_multi_process_pool(target_devices=device)
            try:
                embeddings = self.model.encode_multi_process(
                    prediction, pool=pool, batch_size=batch_size
                )
            finally:
                # Always tear the pool down, even if encoding raised, so the
                # worker processes are not left running.
                self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding
            embeddings = self.model.encode(
                prediction,
                device=device,
                batch_size=batch_size,
                show_progress_bar=verbose,
            )
        return embeddings


def get_encoder(model_name: str) -> Encoder:
    """
    Build the encoder for the given model name.

    Args:
        model_name (str): Name of the model to instantiate.
            Options:
                paraphrase-distilroberta-base-v1,
                stsb-roberta-large,
                sentence-transformers/use-cmlm-multilingual
            Any other model on Huggingface/SentenceTransformer that is
            supported by SentenceTransformer can be used as well.

    Returns:
        Encoder: An SBertEncoder constructed from model_name.

    Raises:
        EnvironmentError: If constructing the encoder raised an
            EnvironmentError; re-raised with the same message and the
            original context suppressed.
        RuntimeError: If constructing the encoder raised any other
            exception; carries the original message, context suppressed.
    """
    try:
        return SBertEncoder(model_name)
    except EnvironmentError as env_failure:
        message = str(env_failure)
        raise EnvironmentError(message) from None
    except Exception as other_failure:
        message = str(other_failure)
        raise RuntimeError(message) from None