|
from typing import Dict, List, Union |
|
import torch |
|
from transformers import AutoModel |
|
from custom_tokenizer import CustomPhobertTokenizer |
|
|
|
|
|
def mean_pooling(model_output, attention_mask): |
|
token_embeddings = model_output[ |
|
0 |
|
] |
|
input_mask_expanded = ( |
|
attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() |
|
) |
|
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp( |
|
input_mask_expanded.sum(1), min=1e-9 |
|
) |
|
|
|
|
|
class PreTrainedPipeline: |
|
def __init__(self, path="."): |
|
self.model = AutoModel.from_pretrained(path) |
|
self.tokenizer = CustomPhobertTokenizer.from_pretrained(path) |
|
|
|
def __call__(self, inputs: Dict[str, Union[str, List[str]]]) -> List[float]: |
|
""" |
|
Args: |
|
inputs (Dict[str, Union[str, List[str]]]): |
|
a dictionary containing a query sentence and a list of key sentences |
|
""" |
|
|
|
|
|
sentences = [inputs["source_sentence"]] + inputs["sentences"] |
|
|
|
|
|
encoded_input = self.tokenizer( |
|
sentences, padding=True, truncation=True, return_tensors="pt" |
|
) |
|
|
|
|
|
with torch.no_grad(): |
|
model_output = self.model(**encoded_input) |
|
|
|
|
|
sentence_embeddings = mean_pooling( |
|
model_output, encoded_input["attention_mask"] |
|
) |
|
|
|
|
|
query_embedding = sentence_embeddings[0] |
|
key_embeddings = sentence_embeddings[1:] |
|
|
|
|
|
cosine_similarities = torch.nn.functional.cosine_similarity( |
|
query_embedding.unsqueeze(0), key_embeddings |
|
) |
|
|
|
|
|
scores = cosine_similarities.tolist() |
|
|
|
return scores |
|
|
|
|
|
if __name__ == "__main__": |
|
inputs = { |
|
"source_sentence": "Anh ấy đang là sinh viên năm cuối", |
|
"sentences": [ |
|
"Anh ấy học tại Đại học Bách khoa Hà Nội, chuyên ngành Khoa học máy tính", |
|
"Anh ấy đang làm việc tại nhà máy sản xuất linh kiện điện tử", |
|
"Anh ấy chuẩn bị đi du học nước ngoài", |
|
"Anh ấy sắp mở cửa hàng bán mỹ phẩm", |
|
"Nhà anh ấy có rất nhiều cây cảnh", |
|
], |
|
} |
|
|
|
pipeline = PreTrainedPipeline() |
|
res = pipeline(inputs) |
|
|