File size: 1,200 Bytes
eb6d478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from transformers import AutoTokenizer
import torch.nn as nn
from bert_classification import CustomBert  # Importer le modèle depuis le fichier bert_classification.py

# Module-level tokenizer for the 'bert-base-uncased' checkpoint, created once
# at import time and reused by predict_category() for every call.
# NOTE(review): presumably this must match the tokenizer CustomBert was
# trained with — confirm against bert_classification.py.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def load_model(model_path, num_classes):
    """Build a ``CustomBert`` classifier and restore its saved weights.

    Args:
        model_path: Location of a checkpoint containing a state dict (the
            object is passed straight to ``torch.load`` and the result to
            ``load_state_dict``).
        num_classes: Number of output classes, forwarded to ``CustomBert``
            as ``n_classes``. Presumably this has to match the value used
            when the checkpoint was produced — verify against training code.

    Returns:
        The ``CustomBert`` instance with weights loaded, switched to
        evaluation mode via ``model.eval()``.
    """
    model = CustomBert(n_classes=num_classes)

    # Deserialize onto the CPU so that a checkpoint written from a GPU run
    # still loads on a machine without CUDA. load_state_dict() copies the
    # values into the model's own parameters, so this does not change where
    # the model itself lives.
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from trusted sources (or pass weights_only=True where the installed
    # torch version supports it).
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)

    model.eval()
    return model

def predict_category(headline, article, model, labels_dict, max_length=100):
    """Classify a news item and report the winning label with its probability.

    The headline and article are joined with a single space, tokenized with
    the module-level ``tokenizer`` (padded/truncated to ``max_length``), and
    fed to ``model`` without gradient tracking.

    Args:
        headline: Headline text.
        article: Article body text.
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments. Its output is softmaxed over dim 1, so it is assumed
            to return a (1, n_classes) tensor of scores — TODO confirm.
        labels_dict: Mapping from category name to class index.
        max_length: Token length used for padding and truncation.

    Returns:
        A ``(category, score)`` tuple: the category name whose index has the
        highest probability, and that probability rounded to two decimals.
    """
    encoded = tokenizer(
        headline + " " + article,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        logits = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )
        probs = torch.softmax(logits, dim=1)
        # max over the class dimension yields both the top probability and
        # the index of the class it belongs to.
        best_prob, best_idx = probs.max(dim=1)

    confidence = best_prob.item()

    # Invert name -> index into index -> name to translate the prediction.
    index_to_name = {index: name for name, index in labels_dict.items()}
    category = index_to_name[best_idx.item()]

    return category, round(confidence, 2)