# --- File-viewer page header captured with the source (not part of the program) ---
# Federico Galatolo
# first commit
# 168a4de
# raw
# history blame
# No virus
# 818 Bytes
import torch
from transformers import BertModel, BertTokenizerFast
import torch.nn.functional as F
class LaBSE:
    """Sentence encoder backed by the ``setu4993/LaBSE`` BERT checkpoint.

    Calling an instance returns one L2-normalised embedding per input
    sentence, taken from the model's ``pooler_output``.
    """

    def __init__(self):
        """Load the pretrained tokenizer and model and switch to eval mode."""
        self.tokenizer = BertTokenizerFast.from_pretrained("setu4993/LaBSE")
        self.model = BertModel.from_pretrained("setu4993/LaBSE")
        # Inference only: eval() turns off dropout so outputs are deterministic.
        self.model.eval()

    @torch.no_grad()
    def __call__(self, sentences):
        """Embed one sentence or a batch of sentences.

        Args:
            sentences: A single string, or a list/tuple of strings.

        Returns:
            A NumPy array with one row per sentence, each row scaled to unit
            L2 norm. An empty list/tuple yields an array of shape
            ``(0, self.dim)``.
        """
        if isinstance(sentences, str):
            sentences = [sentences]
        elif isinstance(sentences, tuple):
            # A tuple is a batch of sentences; convert it so it is not
            # wrapped whole as if it were a single sentence.
            sentences = list(sentences)
        elif not isinstance(sentences, list):
            sentences = [sentences]

        if not sentences:
            # Nothing to encode: skip the tokenizer/model round-trip and
            # return an empty result with the usual number of columns.
            return torch.empty((0, self.dim)).numpy()

        tokens = self.tokenizer(
            sentences,
            return_tensors="pt",
            padding=True,
            # Cap the sequence length at what the position embeddings
            # support so over-long inputs are truncated instead of failing.
            truncation=True,
            max_length=self.model.config.max_position_embeddings,
        )
        outputs = self.model(**tokens)
        embeddings = outputs.pooler_output
        # Normalise each row (dim=1) to unit length.
        return F.normalize(embeddings, p=2, dim=1).cpu().numpy()

    @property
    def dim(self):
        """Width of the returned embeddings (hard-coded to 768)."""
        return 768
if __name__ == "__main__":
    # Smoke test: embed two short phrases and print the shape of the result.
    encoder = LaBSE()
    vectors = encoder(["odi et amo", "quare id faciam"])
    print(vectors.shape)