Model Card for NADI-2024-baseline
A BERT-based model fine-tuned for single-label Arabic Dialect Identification (ADI). Rather than returning only the most probable dialect, the model's logits are used to generate multi-label predictions.
Model Description
- Model type: A Dialect Identification model fine-tuned on the training sets of NADI 2020, 2021, and 2023, and MADAR 2018.
- Language(s) (NLP): Arabic.
- Fine-tuned from model: MarBERTv2
Multilabel country-level Dialect Identification
Baseline I (Top 90%)
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
DIALECTS = [
    "Algeria",
    "Bahrain",
    "Egypt",
    "Iraq",
    "Jordan",
    "Kuwait",
    "Lebanon",
    "Libya",
    "Morocco",
    "Oman",
    "Palestine",
    "Qatar",
    "Saudi_Arabia",
    "Sudan",
    "Syria",
    "Tunisia",
    "UAE",
    "Yemen",
]
assert len(DIALECTS) == 18
MODEL_NAME = "AMR-KELEG/NADI2024-baseline"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
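def predict_top_p(text, P=0.9):
    """Predict the top dialects with a cumulative confidence of at least P."""
    assert 0 <= P <= 1
    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt")).logits
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    # Label indices sorted from most to least probable.
    topk_predictions = torch.topk(logits, len(DIALECTS)).indices.flatten().tolist()
    predictions = [0 for _ in range(len(DIALECTS))]
    total_prob = 0
    # Add dialects in order of probability until their total mass reaches P.
    for i in range(len(DIALECTS)):
        total_prob += probabilities[topk_predictions[i]]
        predictions[topk_predictions[i]] = 1
        if total_prob >= P:
            break
    # Return the selected dialects in the order they appear in DIALECTS.
    return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]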
s1 = "كيفك يا زلمة"
s1_pred = predict_top_p(s1) # ['Jordan', 'Lebanon', 'Palestine', 'Syria']
print(s1, s1_pred)
s2 = "خليلي في مساج بريفي كيفاش الاتصال"
s2_pred = predict_top_p(s2) # ['Algeria', 'Tunisia']
print(s2, s2_pred)
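The top-P selection rule can also be tried without downloading the model. The following is a minimal sketch, not part of the released code: the helper names `softmax` and `top_p_labels`, the four-label subset, and the probability values are all made up for illustration. It mirrors the logic of `predict_top_p` on a plain list of probabilities.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (hypothetical helper, pure Python)."""
    highest = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - highest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_labels(probabilities, labels, p=0.9):
    """Select the most probable labels whose cumulative probability reaches p."""
    if not 0 <= p <= 1:
        raise ValueError("p must be between 0 and 1")
    if len(probabilities) != len(labels):
        raise ValueError("probabilities and labels must have the same length")
    # Indices sorted from most to least probable.
    ranked = sorted(range(len(labels)), key=lambda i: probabilities[i], reverse=True)
    selected = []
    cumulative = 0.0
    for i in ranked:
        selected.append(i)
        cumulative += probabilities[i]
        if cumulative >= p:
            break
    # Report the selected labels in their original order, as predict_top_p does.
    return [labels[i] for i in sorted(selected)]


# Illustrative values only; these are not outputs of the model.
toy_labels = ["Egypt", "Jordan", "Lebanon", "Syria"]
toy_probs = [0.05, 0.30, 0.25, 0.40]

print(top_p_labels(toy_probs, toy_labels))         # ['Jordan', 'Lebanon', 'Syria']
print(top_p_labels(toy_probs, toy_labels, p=0.5))  # ['Jordan', 'Syria']
```

Lowering `p` yields fewer labels, and at least one label is always returned, because the threshold is checked only after a label has been added.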