Merwan6 commited on
Commit
0cebe35
·
1 Parent(s): fc22127

Commit initial

Browse files
Files changed (9) hide show
  1. .DS_Store +0 -0
  2. .gitignore +5 -0
  3. app.py +74 -0
  4. push_model.py +12 -0
  5. readme.md +77 -0
  6. requirements.txt +84 -0
  7. scripts/inference.py +127 -0
  8. scripts/train.py +137 -0
  9. scripts/utils.py +31 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ /venv
2
+ /models
3
+ push_models.py
4
+ hf_login.py
5
+ /scripts/__pycache__
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from scripts.inference import (
4
+ zero_shot_inference,
5
+ few_shot_inference,
6
+ base_model_inference,
7
+ fine_tuned_inference
8
+ )
9
+
10
+ def predict_with_model(text, model_type):
11
+ """
12
+ Applique la stratégie de classification sélectionnée sur un texte donné
13
+ et retourne la catégorie prédite avec les scores de confiance.
14
+
15
+ Args:
16
+ text (str): Le texte à analyser (actualité).
17
+ model_type (str): Le type de modèle sélectionné ("Zero-shot", "Few-shot", etc.).
18
+
19
+ Returns:
20
+ tuple:
21
+ - str: Catégorie prédite.
22
+ - pandas.DataFrame: Tableau des scores de confiance par classe.
23
+ """
24
+
25
+ #Sélection du modèle d'inférence en fonction du choix utilisateur
26
+ if model_type == "Zero-shot":
27
+ prediction, scores = zero_shot_inference(text)
28
+ elif model_type == "Few-shot":
29
+ prediction, scores = few_shot_inference(text)
30
+ elif model_type == "Fine-tuned":
31
+ prediction, scores = fine_tuned_inference(text)
32
+ elif model_type == "Base model":
33
+ prediction, scores = base_model_inference(text)
34
+ else:
35
+ return "Modèle inconnu", pd.DataFrame()
36
+
37
+ #Convertit les scores (dict) en DataFrame pour affichage dans Gradio
38
+ scores_df = pd.DataFrame([
39
+ {"Classe": label, "Score": score} for label, score in scores.items()
40
+ ])
41
+
42
+ return prediction, scores_df
43
+
44
+ #Définition de l'interface utilisateur avec Gradio
45
+ iface = gr.Interface(
46
+ fn=predict_with_model, #Fonction appelée au clic de l'utilisateur
47
+ inputs=[
48
+ gr.Textbox(
49
+ lines=4,
50
+ placeholder="Entrez une phrase d'actualité ici...",
51
+ label="Texte à classifier"
52
+ ),
53
+ gr.Radio(
54
+ choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"],
55
+ label="Choisir le modèle",
56
+ value="Base model" #Valeur par défaut
57
+ )
58
+ ],
59
+ outputs=[
60
+ gr.Label(label="Catégorie prédite"), #Affiche la prédiction principale
61
+ gr.BarPlot( #Affiche les scores de confiance
62
+ label="Scores de confiance",
63
+ x="Classe",
64
+ y="Score",
65
+ color="Classe"
66
+ )
67
+ ],
68
+ title="Classification AG News (4 stratégies)",
69
+ description="Comparer un modèle préentraîné, Zero-shot, Few-shot et Fine-tuned sur AG News"
70
+ )
71
+
72
+ #Lancement de l'application
73
+ if __name__ == "__main__":
74
+ iface.launch()
push_model.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
2
+
3
+ model_path = "models/fine_tuned_model" # Chemin vers ton modèle fine-tuné
4
+ repo_name = "agnews-finetuned-bert" # Nom public sur Hugging Face
5
+
6
+ # Charger le modèle et le tokenizer
7
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
8
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
9
+
10
+ # Uploader
11
+ model.push_to_hub(repo_name)
12
+ tokenizer.push_to_hub(repo_name)
readme.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 📰 AG News Text Classification Demo
2
+
3
+ Ce projet présente une application de classification de textes d’actualité basée sur le dataset **AG News**.
4
+ L'objectif est de comparer plusieurs stratégies d'inférence de modèles Transformers pour la classification de texte.
5
+
6
+ ---
7
+
8
+ ## 🚀 Démo en ligne
9
+
10
+ L’application est disponible ici :
11
+ [**Lien vers la démo Hugging Face Space**](https://huggingface.co/spaces/TON_UTILISATEUR/TON_ESPACE) *(à remplacer par ton lien)*
12
+
13
+ ---
14
+
15
+ ## 📂 Organisation du projet
16
+
17
+ - `app.py` : interface Gradio avec deux onglets (`Demo` + `Documentation`)
18
+ - `scripts/inference.py` : fonctions d’inférence pour 4 types de modèles
19
+ - `scripts/train.py` : script d’entraînement du modèle BERT fine-tuné sur AG News
20
+ - `scripts/utils.py` : calcul des métriques d’évaluation (accuracy, F1, etc.)
21
+ - `requirements.txt` : liste des dépendances Python
22
+
23
+ ---
24
+
25
+ ## 🧠 Description des modèles utilisés
26
+
27
+ 1. **Base model**
28
+ Modèle BERT préentraîné `textattack/bert-base-uncased-ag-news` utilisé directement sans fine-tuning.
29
+
30
+ 2. **Zero-shot**
31
+ Modèle `facebook/bart-large-mnli` utilisé pour classification zero-shot via pipeline Hugging Face.
32
+
33
+ 3. **Few-shot**
34
+ Approche zero-shot avec exemples dans le prompt (prompt engineering).
35
+
36
+ 4. **Fine-tuned model**
37
+ Modèle BERT `bert-base-uncased` entraîné sur un sous-ensemble équilibré du dataset AG News (3000 exemples par classe), sauvegardé sur Hugging Face Hub sous `Merwan611/agnews-finetuned-bert`.
38
+
39
+ ---
40
+
41
+ ## 📊 Données et entraînement
42
+
43
+ - **Dataset** : AG News (4 classes : World, Sports, Business, Sci/Tech)
44
+ - **Préprocessing** : tokenisation avec `AutoTokenizer` BERT
45
+ - **Entraînement** : 3 epochs, batch size 32, métrique optimisée : accuracy
46
+ - **Sous-échantillonnage** pour accélérer l’entraînement : 3000 exemples par classe pour le train, 1000 par classe pour le test
47
+
48
+ ---
49
+
50
+ ## 📈 Performances
51
+
52
+ Les métriques calculées sont :
53
+ - Accuracy
54
+ - Precision (moyenne pondérée)
55
+ - Recall (moyenne pondérée)
56
+ - F1-score (moyenne pondérée)
57
+
58
+ Le modèle fine-tuné atteint généralement une meilleure précision que le modèle de base ou les approches zero-shot.
59
+
60
+ ---
61
+
62
+ ## ⚙️ Lancer l’application localement
63
+
64
+ 1. Cloner le repo
65
+ 2. Créer un environnement virtuel Python
66
+ 3. Installer les dépendances :
67
+ ```bash
68
+ pip install -r requirements.txt
69
+ 4. Lancer l'application python app.py
70
+
71
+ ## ✍️ Auteur
72
+ Réalisé par Merwan BOUDRIAS dans le cadre d’une démonstration technique.
73
+
74
+ ## 📚 Références
75
+ Dataset AG News : https://huggingface.co/datasets/ag_news
76
+ Modèles Transformers : https://huggingface.co/models
77
+ Documentation Gradio : https://gradio.app/
requirements.txt ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==1.7.0
2
+ aiofiles==24.1.0
3
+ aiohappyeyeballs==2.6.1
4
+ aiohttp==3.12.9
5
+ aiosignal==1.3.2
6
+ annotated-types==0.7.0
7
+ anyio==4.9.0
8
+ attrs==25.3.0
9
+ certifi==2025.4.26
10
+ charset-normalizer==3.4.2
11
+ click==8.2.1
12
+ datasets==3.6.0
13
+ dill==0.3.8
14
+ dotenv==0.9.9
15
+ fastapi==0.115.12
16
+ ffmpy==0.6.0
17
+ filelock==3.18.0
18
+ frozenlist==1.6.2
19
+ fsspec==2025.3.0
20
+ gradio==5.33.0
21
+ gradio_client==1.10.2
22
+ groovy==0.1.2
23
+ h11==0.16.0
24
+ hf-xet==1.1.3
25
+ httpcore==1.0.9
26
+ httpx==0.28.1
27
+ huggingface-hub==0.32.4
28
+ idna==3.10
29
+ Jinja2==3.1.6
30
+ joblib==1.5.1
31
+ markdown-it-py==3.0.0
32
+ MarkupSafe==3.0.2
33
+ mdurl==0.1.2
34
+ mpmath==1.3.0
35
+ multidict==6.4.4
36
+ multiprocess==0.70.16
37
+ networkx==3.5
38
+ numpy==2.2.6
39
+ orjson==3.10.18
40
+ packaging==25.0
41
+ pandas==2.3.0
42
+ pillow==11.2.1
43
+ propcache==0.3.1
44
+ protobuf==6.31.1
45
+ psutil==7.0.0
46
+ pyarrow==20.0.0
47
+ pydantic==2.11.5
48
+ pydantic_core==2.33.2
49
+ pydub==0.25.1
50
+ Pygments==2.19.1
51
+ python-dateutil==2.9.0.post0
52
+ python-dotenv==1.1.0
53
+ python-multipart==0.0.20
54
+ pytz==2025.2
55
+ PyYAML==6.0.2
56
+ regex==2024.11.6
57
+ requests==2.32.3
58
+ rich==14.0.0
59
+ ruff==0.11.12
60
+ safehttpx==0.1.6
61
+ safetensors==0.5.3
62
+ scikit-learn==1.6.1
63
+ scipy==1.15.3
64
+ semantic-version==2.10.0
65
+ shellingham==1.5.4
66
+ six==1.17.0
67
+ sniffio==1.3.1
68
+ starlette==0.46.2
69
+ sympy==1.14.0
70
+ threadpoolctl==3.6.0
71
+ tokenizers==0.21.1
72
+ tomlkit==0.13.3
73
+ torch==2.7.1
74
+ tqdm==4.67.1
75
+ transformers==4.52.4
76
+ typer==0.16.0
77
+ typing-inspection==0.4.1
78
+ typing_extensions==4.14.0
79
+ tzdata==2025.2
80
+ urllib3==2.4.0
81
+ uvicorn==0.34.3
82
+ websockets==15.0.1
83
+ xxhash==3.5.0
84
+ yarl==1.20.0
scripts/inference.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from dotenv import load_dotenv
4
+ import os
5
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
6
+
7
+ #Mapping entre les ID des classes et les labels textuels
8
+ id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
9
+
10
+
11
+ def zero_shot_inference(text):
12
+ """
13
+ Effectue une classification zero-shot à l'aide du modèle BART MNLI.
14
+
15
+ Args:
16
+ text (str): Texte à classifier.
17
+
18
+ Returns:
19
+ tuple:
20
+ - str: Label prédit.
21
+ - dict: Dictionnaire {label: score} pour chaque classe.
22
+ """
23
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
24
+ candidate_labels = list(id2label.values())
25
+ result = classifier(text, candidate_labels)
26
+ prediction = result["labels"][0]
27
+ # Formatage des scores avec 4 décimales
28
+ scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
29
+ return prediction, scores
30
+
31
+
32
+ def few_shot_inference(text):
33
+ """
34
+ Simule un few-shot learning en injectant des exemples dans le prompt (type prompt engineering).
35
+
36
+ Args:
37
+ text (str): Texte à classifier.
38
+
39
+ Returns:
40
+ tuple:
41
+ - str: Label prédit.
42
+ - dict: Scores pour chaque classe.
43
+ """
44
+ classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
45
+
46
+ #Exemples donnés au modèle pour le guider (prompt engineering)
47
+ examples = [
48
+ ("The president met the UN delegation to discuss global peace.", "World"),
49
+ ("The football team won their match last night.", "Sports"),
50
+ ("The company reported a big profit this quarter.", "Business"),
51
+ ("New research in AI shows promising results.", "Sci/Tech")
52
+ ]
53
+
54
+ #Construction du prompt avec des exemples
55
+ prompt = ""
56
+ for example_text, example_label in examples:
57
+ prompt += f"Text: {example_text}\nLabel: {example_label}\n\n"
58
+ prompt += f"Text: {text}\nLabel:"
59
+
60
+ candidate_labels = list(id2label.values())
61
+ result = classifier(prompt, candidate_labels)
62
+ prediction = result["labels"][0]
63
+ scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
64
+ return prediction, scores
65
+
66
+
67
+ def base_model_inference(text):
68
+ """
69
+ Utilise un modèle BERT préentraîné sur AG News (sans fine-tuning personnalisé).
70
+
71
+ Args:
72
+ text (str): Texte à classifier.
73
+
74
+ Returns:
75
+ tuple:
76
+ - str: Label prédit.
77
+ - dict: Scores softmax par classe.
78
+ """
79
+ model_name = "textattack/bert-base-uncased-ag-news"
80
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
81
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
82
+
83
+ #Encodage du texte
84
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
85
+
86
+ #Prédiction sans calcul de gradients
87
+ with torch.no_grad():
88
+ outputs = model(**inputs)
89
+
90
+ #Calcul des probabilités avec softmax
91
+ probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
92
+
93
+ pred_id = probs.argmax()
94
+ prediction = id2label[pred_id]
95
+ scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
96
+ return prediction, scores
97
+
98
+
99
+ def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"):
100
+ """
101
+ Utilise un modèle BERT fine-tuné personnalisé sur AG News, avec authentification Hugging Face si nécessaire.
102
+
103
+ Args:
104
+ text (str): Texte à classifier.
105
+ model_path (str): Nom du modèle Hugging Face ou chemin local.
106
+
107
+ Returns:
108
+ tuple:
109
+ - str: Label prédit.
110
+ - dict: Scores softmax par classe.
111
+ """
112
+
113
+ #Récupération du token d'auth depuis les variables d'environnement
114
+ token = os.getenv("CLE")
115
+
116
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=token)
117
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
118
+
119
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
120
+ with torch.no_grad():
121
+ outputs = model(**inputs)
122
+
123
+ probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
124
+ pred_id = probs.argmax()
125
+ prediction = id2label[pred_id]
126
+ scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
127
+ return prediction, scores
scripts/train.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import load_dataset, DatasetDict, Dataset
2
+ from transformers import (
3
+ AutoTokenizer, AutoModelForSequenceClassification,
4
+ Trainer, TrainingArguments, DataCollatorWithPadding
5
+ )
6
+ import numpy as np
7
+ from utils import compute_metrics
8
+ import os
9
+
10
+ def load_ag_news():
11
+ """
12
+ Charge le jeu de données AG News via Hugging Face datasets.
13
+
14
+ Returns:
15
+ DatasetDict: Contenant les splits train/test.
16
+ """
17
+ dataset = load_dataset("ag_news")
18
+ return dataset
19
+
20
+
21
+ def get_balanced_subset(dataset_split, n_per_class=1000):
22
+ """
23
+ Crée un sous-ensemble équilibré contenant `n_per_class` exemples par classe.
24
+
25
+ Args:
26
+ dataset_split (Dataset): Split de type train ou test.
27
+ n_per_class (int): Nombre d'exemples à garder par classe.
28
+
29
+ Returns:
30
+ Dataset: Sous-ensemble équilibré.
31
+ """
32
+ subsets = []
33
+
34
+ for label in range(4):
35
+ #Filtrage des exemples correspondant à la classe `label`
36
+ filtered = dataset_split.filter(lambda example: example['label'] == label)
37
+ #Sélection des n premiers exemples (ou tous s’il y en a moins)
38
+ subsets.append(filtered.select(range(min(n_per_class, len(filtered)))))
39
+
40
+ #Fusionner les sous-ensembles
41
+ combined_dict = {
42
+ key: sum([subset[key] for subset in subsets], []) for key in subsets[0].features.keys()
43
+ }
44
+
45
+ return Dataset.from_dict(combined_dict)
46
+
47
+
48
+ def preprocess_data(dataset, tokenizer):
49
+ """
50
+ Tokenise le jeu de données avec troncature et padding.
51
+
52
+ Args:
53
+ dataset (DatasetDict): Données d'entraînement et de test.
54
+ tokenizer (AutoTokenizer): Tokenizer à utiliser.
55
+
56
+ Returns:
57
+ DatasetDict: Données tokenisées.
58
+ """
59
+ def preprocess(batch):
60
+ return tokenizer(batch["text"], truncation=True, padding=True)
61
+
62
+ return dataset.map(preprocess, batched=True)
63
+
64
+
65
+ def main():
66
+ """
67
+ Lance le fine-tuning du modèle BERT sur AG News et sauvegarde le modèle.
68
+ """
69
+ #Création des dossiers de sortie
70
+ os.makedirs("../models/fine_tuned_model", exist_ok=True)
71
+ os.makedirs("../logs", exist_ok=True)
72
+
73
+ #Chargement du jeu de données
74
+ dataset = load_ag_news()
75
+
76
+ #Création de sous-ensembles équilibrés (entraînement/test)
77
+ train_subset = get_balanced_subset(dataset["train"], n_per_class=3000)
78
+ test_subset = get_balanced_subset(dataset["test"], n_per_class=1000)
79
+
80
+ dataset_small = DatasetDict({
81
+ "train": train_subset,
82
+ "test": test_subset
83
+ })
84
+
85
+ #Chargement du tokenizer
86
+ model_name = "bert-base-uncased"
87
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
88
+
89
+ #Prétraitement (tokenisation)
90
+ encoded = preprocess_data(dataset_small, tokenizer)
91
+
92
+ #Préparation des entrées avec padding dynamique
93
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
94
+
95
+ #Chargement du modèle BERT pour classification
96
+ model = AutoModelForSequenceClassification.from_pretrained(
97
+ model_name,
98
+ num_labels=4 #AG News contient 4 classes
99
+ )
100
+
101
+ #Configuration de l'entraînement
102
+ training_args = TrainingArguments(
103
+ output_dir="../models/fine_tuned_model",
104
+ eval_strategy="epoch",
105
+ save_strategy="epoch",
106
+ num_train_epochs=3,
107
+ per_device_train_batch_size=32,
108
+ per_device_eval_batch_size=32,
109
+ load_best_model_at_end=True,
110
+ metric_for_best_model="accuracy",
111
+ logging_dir="../logs",
112
+ seed=42
113
+ )
114
+
115
+ #Définition du trainer Hugging Face
116
+ trainer = Trainer(
117
+ model=model,
118
+ args=training_args,
119
+ train_dataset=encoded["train"],
120
+ eval_dataset=encoded["test"],
121
+ tokenizer=tokenizer,
122
+ data_collator=data_collator,
123
+ compute_metrics=lambda p: compute_metrics(
124
+ np.argmax(p.predictions, axis=1), p.label_ids
125
+ )
126
+ )
127
+
128
+ #Lancement de l'entraînement
129
+ trainer.train()
130
+
131
+ #Sauvegarde finale du modèle
132
+ trainer.save_model("../models/fine_tuned_model")
133
+ print("✅ Modèle sauvegardé dans ../models/fine_tuned_model")
134
+
135
+
136
+ if __name__ == "__main__":
137
+ main()
scripts/utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
2
+
3
+ def compute_metrics(preds, labels):
4
+ """
5
+ Calcule les métriques de classification à partir des prédictions du modèle
6
+ et des labels de vérité terrain (vrais).
7
+
8
+ Args:
9
+ preds (array-like): Les classes prédites par le modèle (entiers).
10
+ labels (array-like): Les vraies classes associées aux exemples (entiers).
11
+
12
+ Returns:
13
+ dict: Dictionnaire contenant les métriques suivantes :
14
+ - "accuracy" : exactitude globale des prédictions
15
+ - "f1" : score F1 pondéré (par classe)
16
+ - "precision" : précision pondérée
17
+ - "recall" : rappel pondéré
18
+ """
19
+
20
+ #Calcule précision, rappel et F1 pondérés selon la taille de chaque classe
21
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
22
+
23
+ #Calcule l'accuracy brute
24
+ acc = accuracy_score(labels, preds)
25
+
26
+ return {
27
+ "accuracy": acc,
28
+ "f1": f1,
29
+ "precision": precision,
30
+ "recall": recall
31
+ }