File size: 3,840 Bytes
a249916
 
 
 
 
 
4f33285
a249916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251bfda
a249916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6deb98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a249916
251bfda
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import abc
from typing import List, Union

from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from .type_aliases import ENCODER_DEVICE_TYPE


class Encoder(abc.ABC):
    """Interface for objects that turn a list of sentences into embedding vectors."""

    @abc.abstractmethod
    def encode(self, prediction: List[str]) -> NDArray:
        """
            Encode a batch of sentences into sentence embeddings.

            Concrete subclasses must override this method; because it is declared with
            ``abc.abstractmethod``, a subclass that does not override it cannot be instantiated.

            Args:
                prediction (List[str]): Sentences to encode.

            Returns:
                NDArray: Embedding matrix with shape (num_sentences, embedding_dim).

            Raises:
                NotImplementedError: If a subclass delegates to this base implementation
                    (e.g. via ``super().encode(...)``).
        """
        message = "Method 'encode' must be implemented in subclass."
        raise NotImplementedError(message)


class SBertEncoder(Encoder):
    """Sentence encoder backed by a ``SentenceTransformer`` model."""

    def __init__(self, model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool):
        """
            Initialize SBertEncoder instance.

            Args:
                model_name (str): Name or path of the Sentence Transformer model.
                device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
                    A list of devices selects multi-process encoding across those devices;
                    any other value is passed straight to ``SentenceTransformer.encode``.
                batch_size (int): Batch size for encoding.
                verbose (bool): Whether to show a progress bar during encoding
                    (only used on the single-device path).
        """
        # NOTE(review): trust_remote_code=True lets the model repository ship custom Python code
        # that is executed on load; only use model names/paths from trusted sources.
        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        self.device = device
        self.batch_size = batch_size
        self.verbose = verbose

    def encode(self, prediction: List[str]) -> NDArray:
        """
           Encode a list of sentences into sentence embeddings.

           Args:
               prediction (List[str]): List of sentences to encode.

           Returns:
               NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """

        # SBert output is always Batch x Dim
        if isinstance(self.device, list):
            # Use multiprocess encoding for list of devices
            pool = self.model.start_multi_process_pool(target_devices=self.device)
            try:
                embeddings = self.model.encode_multi_process(prediction, pool=pool, batch_size=self.batch_size)
            finally:
                # Tear the pool down even if encoding raises; otherwise the worker
                # processes started above would be left running.
                self.model.stop_multi_process_pool(pool)
        else:
            # Single device encoding
            embeddings = self.model.encode(
                prediction,
                device=self.device,
                batch_size=self.batch_size,
                show_progress_bar=self.verbose,
            )

        return embeddings


def get_encoder(model_name: str, device: ENCODER_DEVICE_TYPE, batch_size: int, verbose: bool) -> Encoder:
    """
    Get the encoder instance based on the specified model name.

    Args:
        model_name (str): Name of the model to instantiate
            Options:
                paraphrase-distilroberta-base-v1,
                stsb-roberta-large,
                sentence-transformers/use-cmlm-multilingual
            Furthermore, you can use any model on Huggingface/SentenceTransformer that is supported by
            SentenceTransformer.

        device (Union[str, int, List[Union[str, int]]]): Device specification for the encoder
            (e.g., "cuda", 0 for GPU, "cpu"). A list of devices enables multi-process encoding.
        batch_size (int): Batch size for encoding.
        verbose (bool): Whether to show a progress bar while encoding (single-device encoding only).

    Returns:
        Encoder: Instance of the selected encoder based on the model_name.

    Raises:
        EnvironmentError: If constructing the encoder raises an ``EnvironmentError``/``OSError``
            (for example, when the model cannot be located or loaded).
        RuntimeError: If constructing the encoder fails with any other exception.
    """

    try:
        encoder = SBertEncoder(model_name, device, batch_size, verbose)
    except EnvironmentError as err:
        # Keep the same exception type and message for callers, but chain the original
        # exception as __cause__ so its type and traceback are not lost when debugging.
        raise EnvironmentError(str(err)) from err
    except Exception as err:
        raise RuntimeError(str(err)) from err

    return encoder