File size: 818 Bytes
168a4de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from transformers import BertModel, BertTokenizerFast
import torch.nn.functional as F

class LaBSE:
    """Sentence encoder wrapping a LaBSE BERT checkpoint from Hugging Face.

    Calling an instance returns L2-normalised sentence embeddings taken from
    the model's ``pooler_output``, as a NumPy array with one row per input
    sentence.
    """

    def __init__(self, model_name="setu4993/LaBSE"):
        """Load the tokenizer and model, then switch the model to eval mode.

        Args:
            model_name: Hugging Face hub id or local path of the checkpoint.
                Defaults to ``"setu4993/LaBSE"``.
        """
        self.tokenizer = BertTokenizerFast.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        # Inference only: eval() turns off dropout so repeated calls agree.
        self.model.eval()

    @torch.no_grad()
    def __call__(self, sentences):
        """Embed one sentence or a batch of sentences.

        Args:
            sentences: a single ``str``, or any iterable of ``str``
                (list, tuple, generator, ...).

        Returns:
            A NumPy array of shape ``(number_of_sentences, self.dim)`` whose
            rows have unit L2 norm. An empty batch yields ``(0, self.dim)``.
        """
        # Only a bare string is a single sentence; any other iterable is a
        # batch. (Wrapping a tuple in a list would make the tokenizer treat it
        # as a sentence pair and return a single embedding.)
        if isinstance(sentences, str):
            sentences = [sentences]
        else:
            sentences = list(sentences)
        if not sentences:
            # Don't hand an empty batch to the tokenizer/model; return an
            # empty array with the usual column count instead.
            return torch.empty(0, self.dim).numpy()
        # truncation=True clips inputs longer than the model's maximum
        # sequence length rather than passing over-long sequences to BERT.
        tokens = self.tokenizer(
            sentences, return_tensors="pt", padding=True, truncation=True
        )
        outputs = self.model(**tokens)
        embeddings = outputs.pooler_output
        # Normalise each row (dim=1) so dot products are cosine similarities.
        return F.normalize(embeddings, p=2, dim=1).cpu().numpy()

    @property
    def dim(self):
        """Embedding width, read from the loaded model's config (768 for LaBSE)."""
        return self.model.config.hidden_size

if __name__ == "__main__":
    # Smoke test: embed two Latin sentences and report the array's shape.
    encoder = LaBSE()
    sample = ["odi et amo", "quare id faciam"]
    embeddings = encoder(sample)
    print(embeddings.shape)