
Model Card for NADI-2024-baseline

A BERT-based model fine-tuned to perform single-label Arabic Dialect Identification (ADI). At inference time, instead of predicting only the most probable dialect, the logits are used to generate multi-label predictions: the top dialects whose cumulative probability reaches a chosen threshold, as shown in the example below.

Model Description

  • Model type: A Dialect Identification model fine-tuned on the training sets of NADI 2020, 2021, and 2023, and of MADAR 2018.
  • Language(s) (NLP): Arabic.
  • Finetuned from model: MarBERTv2

Multilabel country-level Dialect Identification

Baseline I (Top 90%)

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

DIALECTS = [
    "Algeria",
    "Bahrain",
    "Egypt",
    "Iraq",
    "Jordan",
    "Kuwait",
    "Lebanon",
    "Libya",
    "Morocco",
    "Oman",
    "Palestine",
    "Qatar",
    "Saudi_Arabia",
    "Sudan",
    "Syria",
    "Tunisia",
    "UAE",
    "Yemen",
]
assert len(DIALECTS) == 18

MODEL_NAME = "AMR-KELEG/NADI2024-baseline"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

def predict_top_p(text, P=0.9):
    """Predict the top dialects with an accumulative confidence of at least P."""
    assert P <= 1 and P >= 0

    logits = model(**tokenizer(text, return_tensors="pt")).logits
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    topk_predictions = torch.topk(logits, 18).indices.flatten().tolist()

    predictions = [0 for _ in range(18)]
    total_prob = 0

    for i in range(18):
        total_prob += probabilities[topk_predictions[i]]
        predictions[topk_predictions[i]] = 1
        if total_prob >= P:
            break

    return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]

s1 = "كيفك يا زلمة"
s1_pred = predict_top_p(s1) # ['Jordan', 'Lebanon', 'Palestine', 'Syria']
print(s1, s1_pred)

s2 = "خليلي في مساج بريفي كيفاش الاتصال"
s2_pred = predict_top_p(s2) # ['Algeria', 'Tunisia']
print(s2, s2_pred)
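
For comparison, the same model can also be used as a conventional single-label classifier by taking the argmax of the logits. The helper below is a minimal sketch, not part of the official baseline; it reuses the tokenizer, model, and DIALECTS objects defined above, and the name predict_single_label is chosen here for illustration.

def predict_single_label(text):
    """Return only the single most probable dialect (plain argmax over the logits)."""
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt")).logits
    return DIALECTS[logits.argmax(dim=1).item()]

print(s1, predict_single_label(s1))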