import abc
from typing import List, Union
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer
from .type_aliases import ENCODER_DEVICE_TYPE


class Encoder(abc.ABC):
    @abc.abstractmethod
    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Abstract method to encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to print verbose information during encoding.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        raise NotImplementedError("Method 'encode' must be implemented in subclass.")


class SBertEncoder(Encoder):
    def __init__(self, model_name: str):
        """
        Initialize an SBertEncoder instance.

        Args:
            model_name (str): Name or path of the Sentence Transformer model.
        """
        self.model = SentenceTransformer(model_name, trust_remote_code=True)

    def encode(
        self,
        prediction: List[str],
        *,
        device: ENCODER_DEVICE_TYPE = "cpu",
        batch_size: int = 32,
        verbose: bool = False,
    ) -> NDArray:
        """
        Encode a list of sentences into sentence embeddings.

        Args:
            prediction (List[str]): List of sentences to encode.
            device (Union[str, int, List[Union[str, int]]]): Device specification for encoding.
            batch_size (int): Batch size for encoding.
            verbose (bool): Whether to print verbose information during encoding.

        Returns:
            NDArray: Array of sentence embeddings with shape (num_sentences, embedding_dim).
        """
        # SBert output is always Batch x Dim
        if isinstance(device, list):
            # Use multi-process encoding for a list of devices
            pool = self.model.start_multi_process_pool(target_devices=device)
            embeddings = self.model.encode_multi_process(
                prediction, pool=pool, batch_size=batch_size
            )
            self.model.stop_multi_process_pool(pool)
        else:
            # Single-device encoding
            embeddings = self.model.encode(
                prediction,
                device=device,
                batch_size=batch_size,
                show_progress_bar=verbose,
            )
        return embeddings
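

# Multi-device sketch (hedged: "all-MiniLM-L6-v2", "cuda:0" and "cuda:1" are
# placeholder names, assuming a downloadable checkpoint and two visible GPUs).
# Passing a list of devices routes encoding through SentenceTransformer's
# start_multi_process_pool / encode_multi_process / stop_multi_process_pool:
#
#     encoder = SBertEncoder("all-MiniLM-L6-v2")
#     embeddings = encoder.encode(corpus, device=["cuda:0", "cuda:1"], batch_size=64)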


def get_encoder(model_name: str) -> Encoder:
    """
    Get an encoder instance based on the specified model name.

    Args:
        model_name (str): Name of the model to instantiate.
            Options:
                paraphrase-distilroberta-base-v1,
                stsb-roberta-large,
                sentence-transformers/use-cmlm-multilingual
            Furthermore, any model hosted on the Huggingface Hub that is supported by
            SentenceTransformer can be used.

    Returns:
        Encoder: Instance of the selected encoder based on the model_name.

    Raises:
        EnvironmentError/RuntimeError: If an unsupported model_name is provided.
    """
    try:
        encoder = SBertEncoder(model_name)
    except EnvironmentError as err:
        raise EnvironmentError(str(err)) from None
    except Exception as err:
        raise RuntimeError(str(err)) from None
    return encoder
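

# Usage sketch (illustrative only; "all-MiniLM-L6-v2" is an example checkpoint,
# not one this module mandates). Because of the relative import above, use this
# module from within its package rather than running it as a script:
#
#     encoder = get_encoder("all-MiniLM-L6-v2")
#     sentences = ["The cat sits on the mat.", "A feline rests on a rug."]
#     embeddings = encoder.encode(sentences, device="cpu", batch_size=8, verbose=True)
#     embeddings.shape  # (2, embedding_dim), e.g. (2, 384) for MiniLM models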