import torch
from constants import DIALECTS, DIALECTS_WITH_LABELS


def predict_top_p(model, tokenizer, text, P=0.9):
    """Predict the top dialects whose cumulative confidence is at least P (0.9 by default).

    The model is expected to output one logit per dialect, in the following order:
    Algeria, Bahrain, Egypt, Iraq, Jordan, Kuwait, Lebanon, Libya, Morocco, Oman,
    Palestine, Qatar, Saudi_Arabia, Sudan, Syria, Tunisia, UAE, Yemen.
    """
    assert 0 <= P <= 1

    # Tokenize the text and run the model to obtain one logit per dialect.
    with torch.no_grad():
        logits = model(**tokenizer(text, return_tensors="pt")).logits

    # Turn the logits into a probability distribution over the dialects,
    # and sort the dialect indices by decreasing probability.
    probabilities = torch.softmax(logits, dim=1).flatten().tolist()
    sorted_indices = torch.topk(logits, len(DIALECTS)).indices.flatten().tolist()

    predictions = [0 for _ in range(len(DIALECTS))]
    total_prob = 0

    # Greedily add the most probable dialects until their cumulative probability reaches P.
    for index in sorted_indices:
        total_prob += probabilities[index]
        predictions[index] = 1
        if total_prob >= P:
            break

    return [DIALECTS[i] for i, p in enumerate(predictions) if p == 1]
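

# Usage sketch, assuming the classifier is a Hugging Face sequence-classification
# checkpoint fine-tuned for the 18 dialects listed above; the checkpoint name below
# is a placeholder, not the actual model identifier.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model_name = "org/arabic-dialect-identification"  # hypothetical checkpoint name
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)

    # Returns the smallest set of top dialects whose cumulative probability is >= 0.9.
    print(predict_top_p(model, tokenizer, "الجو حلو اليوم", P=0.9))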