TestApp / components /language_identification.py
menikev's picture
Upload 9 files
d2ed505 verified
raw
history blame
730 Bytes
from functools import lru_cache

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
DEFAULT_MODEL_CKPT = "papluca/xlm-roberta-base-language-detection"


@lru_cache(maxsize=4)
def _load_language_model(model_ckpt):
    """Load and memoize the (tokenizer, model) pair for *model_ckpt*.

    ``from_pretrained`` reads the checkpoint (and may fetch it remotely), which
    is far too costly to repeat on every classification call, so the pair is
    loaded once per checkpoint and reused. The cache is bounded so passing many
    distinct checkpoints cannot grow memory without limit.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
    model = AutoModelForSequenceClassification.from_pretrained(model_ckpt)
    return tokenizer, model


def language_identification(texts, model_ckpt=DEFAULT_MODEL_CKPT):
    """Identify the language of a single piece of text.

    Args:
        texts: The text to classify. It is wrapped in a one-element list before
            tokenization, so it is treated as ONE input — presumably a ``str``;
            NOTE(review): confirm callers never pass a list here.
        model_ckpt: Checkpoint name/path of a sequence-classification model whose
            ``config.id2label`` maps class ids to language labels. Defaults to
            the checkpoint this component has always used.

    Returns:
        dict: ``{label: probability}`` holding only the top-scoring label and
        its softmax probability as a Python float.
    """
    tokenizer, model = _load_language_model(model_ckpt)

    inputs = tokenizer([texts], padding=True, truncation=True, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=-1)

    # Keep only the best class per row; with a single input this yields one entry.
    top_probs, top_ids = probs.max(dim=-1)
    id2label = model.config.id2label
    return {id2label[i.item()]: p.item() for i, p in zip(top_ids, top_probs)}