vietnamese-bi-encoder / pipeline.py
phamson02
add word segmentation before tokenization
c1d85a2
from typing import Dict, List, Union
import torch
from transformers import AutoModel
from custom_tokenizer import CustomPhobertTokenizer
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[
0
] # First element of model_output contains all token embeddings
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
)
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
input_mask_expanded.sum(1), min=1e-9
)
class PreTrainedPipeline:
def __init__(self, path="."):
self.model = AutoModel.from_pretrained(path)
self.tokenizer = CustomPhobertTokenizer.from_pretrained(path)
def __call__(self, inputs: Dict[str, Union[str, List[str]]]) -> List[float]:
"""
Args:
inputs (Dict[str, Union[str, List[str]]]):
a dictionary containing a query sentence and a list of key sentences
"""
# Combine the query sentence and key sentences into one list
sentences = [inputs["source_sentence"]] + inputs["sentences"]
# Tokenize sentences
encoded_input = self.tokenizer(
sentences, padding=True, truncation=True, return_tensors="pt"
)
# Compute token embeddings
with torch.no_grad():
model_output = self.model(**encoded_input)
# Perform pooling to get sentence embeddings
sentence_embeddings = mean_pooling(
model_output, encoded_input["attention_mask"]
)
# Separate the query embedding from the key embeddings
query_embedding = sentence_embeddings[0]
key_embeddings = sentence_embeddings[1:]
# Compute cosine similarities (or any other comparison method you prefer)
cosine_similarities = torch.nn.functional.cosine_similarity(
query_embedding.unsqueeze(0), key_embeddings
)
# Convert the tensor of cosine similarities to a list of floats
scores = cosine_similarities.tolist()
return scores
if __name__ == "__main__":
inputs = {
"source_sentence": "Anh ấy đang là sinh viên năm cuối",
"sentences": [
"Anh ấy học tại Đại học Bách khoa Hà Nội, chuyên ngành Khoa học máy tính",
"Anh ấy đang làm việc tại nhà máy sản xuất linh kiện điện tử",
"Anh ấy chuẩn bị đi du học nước ngoài",
"Anh ấy sắp mở cửa hàng bán mỹ phẩm",
"Nhà anh ấy có rất nhiều cây cảnh",
],
}
pipeline = PreTrainedPipeline()
res = pipeline(inputs)