Spaces:
Sleeping
Sleeping
Merwan6
commited on
Commit
·
0cebe35
1
Parent(s):
fc22127
Commit initial
Browse files- .DS_Store +0 -0
- .gitignore +5 -0
- app.py +74 -0
- push_model.py +12 -0
- readme.md +77 -0
- requirements.txt +84 -0
- scripts/inference.py +127 -0
- scripts/train.py +137 -0
- scripts/utils.py +31 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitignore
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/venv
|
2 |
+
/models
|
3 |
+
push_models.py
|
4 |
+
hf_login.py
|
5 |
+
/scripts/__pycache__
|
app.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
from scripts.inference import (
|
4 |
+
zero_shot_inference,
|
5 |
+
few_shot_inference,
|
6 |
+
base_model_inference,
|
7 |
+
fine_tuned_inference
|
8 |
+
)
|
9 |
+
|
10 |
+
def predict_with_model(text, model_type):
|
11 |
+
"""
|
12 |
+
Applique la stratégie de classification sélectionnée sur un texte donné
|
13 |
+
et retourne la catégorie prédite avec les scores de confiance.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
text (str): Le texte à analyser (actualité).
|
17 |
+
model_type (str): Le type de modèle sélectionné ("Zero-shot", "Few-shot", etc.).
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
tuple:
|
21 |
+
- str: Catégorie prédite.
|
22 |
+
- pandas.DataFrame: Tableau des scores de confiance par classe.
|
23 |
+
"""
|
24 |
+
|
25 |
+
#Sélection du modèle d'inférence en fonction du choix utilisateur
|
26 |
+
if model_type == "Zero-shot":
|
27 |
+
prediction, scores = zero_shot_inference(text)
|
28 |
+
elif model_type == "Few-shot":
|
29 |
+
prediction, scores = few_shot_inference(text)
|
30 |
+
elif model_type == "Fine-tuned":
|
31 |
+
prediction, scores = fine_tuned_inference(text)
|
32 |
+
elif model_type == "Base model":
|
33 |
+
prediction, scores = base_model_inference(text)
|
34 |
+
else:
|
35 |
+
return "Modèle inconnu", pd.DataFrame()
|
36 |
+
|
37 |
+
#Convertit les scores (dict) en DataFrame pour affichage dans Gradio
|
38 |
+
scores_df = pd.DataFrame([
|
39 |
+
{"Classe": label, "Score": score} for label, score in scores.items()
|
40 |
+
])
|
41 |
+
|
42 |
+
return prediction, scores_df
|
43 |
+
|
44 |
+
#Définition de l'interface utilisateur avec Gradio
|
45 |
+
iface = gr.Interface(
|
46 |
+
fn=predict_with_model, #Fonction appelée au clic de l'utilisateur
|
47 |
+
inputs=[
|
48 |
+
gr.Textbox(
|
49 |
+
lines=4,
|
50 |
+
placeholder="Entrez une phrase d'actualité ici...",
|
51 |
+
label="Texte à classifier"
|
52 |
+
),
|
53 |
+
gr.Radio(
|
54 |
+
choices=["Base model", "Zero-shot", "Few-shot", "Fine-tuned"],
|
55 |
+
label="Choisir le modèle",
|
56 |
+
value="Base model" #Valeur par défaut
|
57 |
+
)
|
58 |
+
],
|
59 |
+
outputs=[
|
60 |
+
gr.Label(label="Catégorie prédite"), #Affiche la prédiction principale
|
61 |
+
gr.BarPlot( #Affiche les scores de confiance
|
62 |
+
label="Scores de confiance",
|
63 |
+
x="Classe",
|
64 |
+
y="Score",
|
65 |
+
color="Classe"
|
66 |
+
)
|
67 |
+
],
|
68 |
+
title="Classification AG News (4 stratégies)",
|
69 |
+
description="Comparer un modèle préentraîné, Zero-shot, Few-shot et Fine-tuned sur AG News"
|
70 |
+
)
|
71 |
+
|
72 |
+
#Lancement de l'application
|
73 |
+
if __name__ == "__main__":
|
74 |
+
iface.launch()
|
push_model.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
2 |
+
|
3 |
+
model_path = "models/fine_tuned_model" # Chemin vers ton modèle fine-tuné
|
4 |
+
repo_name = "agnews-finetuned-bert" # Nom public sur Hugging Face
|
5 |
+
|
6 |
+
# Charger le modèle et le tokenizer
|
7 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
9 |
+
|
10 |
+
# Uploader
|
11 |
+
model.push_to_hub(repo_name)
|
12 |
+
tokenizer.push_to_hub(repo_name)
|
readme.md
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 📰 AG News Text Classification Demo
|
2 |
+
|
3 |
+
Ce projet présente une application de classification de textes d’actualité basée sur le dataset **AG News**.
|
4 |
+
L'objectif est de comparer plusieurs stratégies d'inférence de modèles Transformers pour la classification de texte.
|
5 |
+
|
6 |
+
---
|
7 |
+
|
8 |
+
## 🚀 Démo en ligne
|
9 |
+
|
10 |
+
L’application est disponible ici :
|
11 |
+
[**Lien vers la démo Hugging Face Space**](https://huggingface.co/spaces/TON_UTILISATEUR/TON_ESPACE) *(à remplacer par ton lien)*
|
12 |
+
|
13 |
+
---
|
14 |
+
|
15 |
+
## 📂 Organisation du projet
|
16 |
+
|
17 |
+
- `app.py` : interface Gradio avec deux onglets (`Demo` + `Documentation`)
|
18 |
+
- `scripts/inference.py` : fonctions d’inférence pour 4 types de modèles
|
19 |
+
- `scripts/train.py` : script d’entraînement du modèle BERT fine-tuné sur AG News
|
20 |
+
- `scripts/utils.py` : calcul des métriques d’évaluation (accuracy, F1, etc.)
|
21 |
+
- `requirements.txt` : liste des dépendances Python
|
22 |
+
|
23 |
+
---
|
24 |
+
|
25 |
+
## 🧠 Description des modèles utilisés
|
26 |
+
|
27 |
+
1. **Base model**
|
28 |
+
Modèle BERT préentraîné `textattack/bert-base-uncased-ag-news` utilisé directement sans fine-tuning.
|
29 |
+
|
30 |
+
2. **Zero-shot**
|
31 |
+
Modèle `facebook/bart-large-mnli` utilisé pour classification zero-shot via pipeline Hugging Face.
|
32 |
+
|
33 |
+
3. **Few-shot**
|
34 |
+
Approche zero-shot avec exemples dans le prompt (prompt engineering).
|
35 |
+
|
36 |
+
4. **Fine-tuned model**
|
37 |
+
Modèle BERT `bert-base-uncased` entraîné sur un sous-ensemble équilibré du dataset AG News (3000 exemples par classe), sauvegardé sur Hugging Face Hub sous `Merwan611/agnews-finetuned-bert`.
|
38 |
+
|
39 |
+
---
|
40 |
+
|
41 |
+
## 📊 Données et entraînement
|
42 |
+
|
43 |
+
- **Dataset** : AG News (4 classes : World, Sports, Business, Sci/Tech)
|
44 |
+
- **Préprocessing** : tokenisation avec `AutoTokenizer` BERT
|
45 |
+
- **Entraînement** : 3 epochs, batch size 32, métrique optimisée : accuracy
|
46 |
+
- **Sous-échantillonnage** pour accélérer l’entraînement : 3000 exemples par classe pour le train, 1000 par classe pour le test
|
47 |
+
|
48 |
+
---
|
49 |
+
|
50 |
+
## 📈 Performances
|
51 |
+
|
52 |
+
Les métriques calculées sont :
|
53 |
+
- Accuracy
|
54 |
+
- Precision (moyenne pondérée)
|
55 |
+
- Recall (moyenne pondérée)
|
56 |
+
- F1-score (moyenne pondérée)
|
57 |
+
|
58 |
+
Le modèle fine-tuné atteint généralement une meilleure précision que le modèle de base ou les approches zero-shot.
|
59 |
+
|
60 |
+
---
|
61 |
+
|
62 |
+
## ⚙️ Lancer l’application localement
|
63 |
+
|
64 |
+
1. Cloner le repo
|
65 |
+
2. Créer un environnement virtuel Python
|
66 |
+
3. Installer les dépendances :
|
67 |
+
```bash
|
68 |
+
pip install -r requirements.txt
|
69 |
+
4. Lancer l'application python app.py
|
70 |
+
|
71 |
+
## ✍️ Auteur
|
72 |
+
Réalisé par Merwan BOUDRIAS dans le cadre d’une démonstration technique.
|
73 |
+
|
74 |
+
## 📚 Références
|
75 |
+
Dataset AG News : https://huggingface.co/datasets/ag_news
|
76 |
+
Modèles Transformers : https://huggingface.co/models
|
77 |
+
Documentation Gradio : https://gradio.app/
|
requirements.txt
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==1.7.0
|
2 |
+
aiofiles==24.1.0
|
3 |
+
aiohappyeyeballs==2.6.1
|
4 |
+
aiohttp==3.12.9
|
5 |
+
aiosignal==1.3.2
|
6 |
+
annotated-types==0.7.0
|
7 |
+
anyio==4.9.0
|
8 |
+
attrs==25.3.0
|
9 |
+
certifi==2025.4.26
|
10 |
+
charset-normalizer==3.4.2
|
11 |
+
click==8.2.1
|
12 |
+
datasets==3.6.0
|
13 |
+
dill==0.3.8
|
14 |
+
dotenv==0.9.9
|
15 |
+
fastapi==0.115.12
|
16 |
+
ffmpy==0.6.0
|
17 |
+
filelock==3.18.0
|
18 |
+
frozenlist==1.6.2
|
19 |
+
fsspec==2025.3.0
|
20 |
+
gradio==5.33.0
|
21 |
+
gradio_client==1.10.2
|
22 |
+
groovy==0.1.2
|
23 |
+
h11==0.16.0
|
24 |
+
hf-xet==1.1.3
|
25 |
+
httpcore==1.0.9
|
26 |
+
httpx==0.28.1
|
27 |
+
huggingface-hub==0.32.4
|
28 |
+
idna==3.10
|
29 |
+
Jinja2==3.1.6
|
30 |
+
joblib==1.5.1
|
31 |
+
markdown-it-py==3.0.0
|
32 |
+
MarkupSafe==3.0.2
|
33 |
+
mdurl==0.1.2
|
34 |
+
mpmath==1.3.0
|
35 |
+
multidict==6.4.4
|
36 |
+
multiprocess==0.70.16
|
37 |
+
networkx==3.5
|
38 |
+
numpy==2.2.6
|
39 |
+
orjson==3.10.18
|
40 |
+
packaging==25.0
|
41 |
+
pandas==2.3.0
|
42 |
+
pillow==11.2.1
|
43 |
+
propcache==0.3.1
|
44 |
+
protobuf==6.31.1
|
45 |
+
psutil==7.0.0
|
46 |
+
pyarrow==20.0.0
|
47 |
+
pydantic==2.11.5
|
48 |
+
pydantic_core==2.33.2
|
49 |
+
pydub==0.25.1
|
50 |
+
Pygments==2.19.1
|
51 |
+
python-dateutil==2.9.0.post0
|
52 |
+
python-dotenv==1.1.0
|
53 |
+
python-multipart==0.0.20
|
54 |
+
pytz==2025.2
|
55 |
+
PyYAML==6.0.2
|
56 |
+
regex==2024.11.6
|
57 |
+
requests==2.32.3
|
58 |
+
rich==14.0.0
|
59 |
+
ruff==0.11.12
|
60 |
+
safehttpx==0.1.6
|
61 |
+
safetensors==0.5.3
|
62 |
+
scikit-learn==1.6.1
|
63 |
+
scipy==1.15.3
|
64 |
+
semantic-version==2.10.0
|
65 |
+
shellingham==1.5.4
|
66 |
+
six==1.17.0
|
67 |
+
sniffio==1.3.1
|
68 |
+
starlette==0.46.2
|
69 |
+
sympy==1.14.0
|
70 |
+
threadpoolctl==3.6.0
|
71 |
+
tokenizers==0.21.1
|
72 |
+
tomlkit==0.13.3
|
73 |
+
torch==2.7.1
|
74 |
+
tqdm==4.67.1
|
75 |
+
transformers==4.52.4
|
76 |
+
typer==0.16.0
|
77 |
+
typing-inspection==0.4.1
|
78 |
+
typing_extensions==4.14.0
|
79 |
+
tzdata==2025.2
|
80 |
+
urllib3==2.4.0
|
81 |
+
uvicorn==0.34.3
|
82 |
+
websockets==15.0.1
|
83 |
+
xxhash==3.5.0
|
84 |
+
yarl==1.20.0
|
scripts/inference.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
import os
|
5 |
+
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
|
6 |
+
|
7 |
+
#Mapping entre les ID des classes et les labels textuels
|
8 |
+
id2label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
|
9 |
+
|
10 |
+
|
11 |
+
def zero_shot_inference(text):
|
12 |
+
"""
|
13 |
+
Effectue une classification zero-shot à l'aide du modèle BART MNLI.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
text (str): Texte à classifier.
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
tuple:
|
20 |
+
- str: Label prédit.
|
21 |
+
- dict: Dictionnaire {label: score} pour chaque classe.
|
22 |
+
"""
|
23 |
+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
24 |
+
candidate_labels = list(id2label.values())
|
25 |
+
result = classifier(text, candidate_labels)
|
26 |
+
prediction = result["labels"][0]
|
27 |
+
# Formatage des scores avec 4 décimales
|
28 |
+
scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
|
29 |
+
return prediction, scores
|
30 |
+
|
31 |
+
|
32 |
+
def few_shot_inference(text):
|
33 |
+
"""
|
34 |
+
Simule un few-shot learning en injectant des exemples dans le prompt (type prompt engineering).
|
35 |
+
|
36 |
+
Args:
|
37 |
+
text (str): Texte à classifier.
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
tuple:
|
41 |
+
- str: Label prédit.
|
42 |
+
- dict: Scores pour chaque classe.
|
43 |
+
"""
|
44 |
+
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
45 |
+
|
46 |
+
#Exemples donnés au modèle pour le guider (prompt engineering)
|
47 |
+
examples = [
|
48 |
+
("The president met the UN delegation to discuss global peace.", "World"),
|
49 |
+
("The football team won their match last night.", "Sports"),
|
50 |
+
("The company reported a big profit this quarter.", "Business"),
|
51 |
+
("New research in AI shows promising results.", "Sci/Tech")
|
52 |
+
]
|
53 |
+
|
54 |
+
#Construction du prompt avec des exemples
|
55 |
+
prompt = ""
|
56 |
+
for example_text, example_label in examples:
|
57 |
+
prompt += f"Text: {example_text}\nLabel: {example_label}\n\n"
|
58 |
+
prompt += f"Text: {text}\nLabel:"
|
59 |
+
|
60 |
+
candidate_labels = list(id2label.values())
|
61 |
+
result = classifier(prompt, candidate_labels)
|
62 |
+
prediction = result["labels"][0]
|
63 |
+
scores = {label: float(f"{score:.4f}") for label, score in zip(result["labels"], result["scores"])}
|
64 |
+
return prediction, scores
|
65 |
+
|
66 |
+
|
67 |
+
def base_model_inference(text):
|
68 |
+
"""
|
69 |
+
Utilise un modèle BERT préentraîné sur AG News (sans fine-tuning personnalisé).
|
70 |
+
|
71 |
+
Args:
|
72 |
+
text (str): Texte à classifier.
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
tuple:
|
76 |
+
- str: Label prédit.
|
77 |
+
- dict: Scores softmax par classe.
|
78 |
+
"""
|
79 |
+
model_name = "textattack/bert-base-uncased-ag-news"
|
80 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
81 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
82 |
+
|
83 |
+
#Encodage du texte
|
84 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
85 |
+
|
86 |
+
#Prédiction sans calcul de gradients
|
87 |
+
with torch.no_grad():
|
88 |
+
outputs = model(**inputs)
|
89 |
+
|
90 |
+
#Calcul des probabilités avec softmax
|
91 |
+
probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
|
92 |
+
|
93 |
+
pred_id = probs.argmax()
|
94 |
+
prediction = id2label[pred_id]
|
95 |
+
scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
|
96 |
+
return prediction, scores
|
97 |
+
|
98 |
+
|
99 |
+
def fine_tuned_inference(text, model_path="Merwan611/agnews-finetuned-bert"):
|
100 |
+
"""
|
101 |
+
Utilise un modèle BERT fine-tuné personnalisé sur AG News, avec authentification Hugging Face si nécessaire.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
text (str): Texte à classifier.
|
105 |
+
model_path (str): Nom du modèle Hugging Face ou chemin local.
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
tuple:
|
109 |
+
- str: Label prédit.
|
110 |
+
- dict: Scores softmax par classe.
|
111 |
+
"""
|
112 |
+
|
113 |
+
#Récupération du token d'auth depuis les variables d'environnement
|
114 |
+
token = os.getenv("CLE")
|
115 |
+
|
116 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path, use_auth_token=token)
|
117 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_path)
|
118 |
+
|
119 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
120 |
+
with torch.no_grad():
|
121 |
+
outputs = model(**inputs)
|
122 |
+
|
123 |
+
probs = F.softmax(outputs.logits, dim=-1)[0].cpu().numpy()
|
124 |
+
pred_id = probs.argmax()
|
125 |
+
prediction = id2label[pred_id]
|
126 |
+
scores = {id2label[i]: float(f"{p:.4f}") for i, p in enumerate(probs)}
|
127 |
+
return prediction, scores
|
scripts/train.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import load_dataset, DatasetDict, Dataset
|
2 |
+
from transformers import (
|
3 |
+
AutoTokenizer, AutoModelForSequenceClassification,
|
4 |
+
Trainer, TrainingArguments, DataCollatorWithPadding
|
5 |
+
)
|
6 |
+
import numpy as np
|
7 |
+
from utils import compute_metrics
|
8 |
+
import os
|
9 |
+
|
10 |
+
def load_ag_news():
|
11 |
+
"""
|
12 |
+
Charge le jeu de données AG News via Hugging Face datasets.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
DatasetDict: Contenant les splits train/test.
|
16 |
+
"""
|
17 |
+
dataset = load_dataset("ag_news")
|
18 |
+
return dataset
|
19 |
+
|
20 |
+
|
21 |
+
def get_balanced_subset(dataset_split, n_per_class=1000):
|
22 |
+
"""
|
23 |
+
Crée un sous-ensemble équilibré contenant `n_per_class` exemples par classe.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
dataset_split (Dataset): Split de type train ou test.
|
27 |
+
n_per_class (int): Nombre d'exemples à garder par classe.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
Dataset: Sous-ensemble équilibré.
|
31 |
+
"""
|
32 |
+
subsets = []
|
33 |
+
|
34 |
+
for label in range(4):
|
35 |
+
#Filtrage des exemples correspondant à la classe `label`
|
36 |
+
filtered = dataset_split.filter(lambda example: example['label'] == label)
|
37 |
+
#Sélection des n premiers exemples (ou tous s’il y en a moins)
|
38 |
+
subsets.append(filtered.select(range(min(n_per_class, len(filtered)))))
|
39 |
+
|
40 |
+
#Fusionner les sous-ensembles
|
41 |
+
combined_dict = {
|
42 |
+
key: sum([subset[key] for subset in subsets], []) for key in subsets[0].features.keys()
|
43 |
+
}
|
44 |
+
|
45 |
+
return Dataset.from_dict(combined_dict)
|
46 |
+
|
47 |
+
|
48 |
+
def preprocess_data(dataset, tokenizer):
|
49 |
+
"""
|
50 |
+
Tokenise le jeu de données avec troncature et padding.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
dataset (DatasetDict): Données d'entraînement et de test.
|
54 |
+
tokenizer (AutoTokenizer): Tokenizer à utiliser.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
DatasetDict: Données tokenisées.
|
58 |
+
"""
|
59 |
+
def preprocess(batch):
|
60 |
+
return tokenizer(batch["text"], truncation=True, padding=True)
|
61 |
+
|
62 |
+
return dataset.map(preprocess, batched=True)
|
63 |
+
|
64 |
+
|
65 |
+
def main():
|
66 |
+
"""
|
67 |
+
Lance le fine-tuning du modèle BERT sur AG News et sauvegarde le modèle.
|
68 |
+
"""
|
69 |
+
#Création des dossiers de sortie
|
70 |
+
os.makedirs("../models/fine_tuned_model", exist_ok=True)
|
71 |
+
os.makedirs("../logs", exist_ok=True)
|
72 |
+
|
73 |
+
#Chargement du jeu de données
|
74 |
+
dataset = load_ag_news()
|
75 |
+
|
76 |
+
#Création de sous-ensembles équilibrés (entraînement/test)
|
77 |
+
train_subset = get_balanced_subset(dataset["train"], n_per_class=3000)
|
78 |
+
test_subset = get_balanced_subset(dataset["test"], n_per_class=1000)
|
79 |
+
|
80 |
+
dataset_small = DatasetDict({
|
81 |
+
"train": train_subset,
|
82 |
+
"test": test_subset
|
83 |
+
})
|
84 |
+
|
85 |
+
#Chargement du tokenizer
|
86 |
+
model_name = "bert-base-uncased"
|
87 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
88 |
+
|
89 |
+
#Prétraitement (tokenisation)
|
90 |
+
encoded = preprocess_data(dataset_small, tokenizer)
|
91 |
+
|
92 |
+
#Préparation des entrées avec padding dynamique
|
93 |
+
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
94 |
+
|
95 |
+
#Chargement du modèle BERT pour classification
|
96 |
+
model = AutoModelForSequenceClassification.from_pretrained(
|
97 |
+
model_name,
|
98 |
+
num_labels=4 #AG News contient 4 classes
|
99 |
+
)
|
100 |
+
|
101 |
+
#Configuration de l'entraînement
|
102 |
+
training_args = TrainingArguments(
|
103 |
+
output_dir="../models/fine_tuned_model",
|
104 |
+
eval_strategy="epoch",
|
105 |
+
save_strategy="epoch",
|
106 |
+
num_train_epochs=3,
|
107 |
+
per_device_train_batch_size=32,
|
108 |
+
per_device_eval_batch_size=32,
|
109 |
+
load_best_model_at_end=True,
|
110 |
+
metric_for_best_model="accuracy",
|
111 |
+
logging_dir="../logs",
|
112 |
+
seed=42
|
113 |
+
)
|
114 |
+
|
115 |
+
#Définition du trainer Hugging Face
|
116 |
+
trainer = Trainer(
|
117 |
+
model=model,
|
118 |
+
args=training_args,
|
119 |
+
train_dataset=encoded["train"],
|
120 |
+
eval_dataset=encoded["test"],
|
121 |
+
tokenizer=tokenizer,
|
122 |
+
data_collator=data_collator,
|
123 |
+
compute_metrics=lambda p: compute_metrics(
|
124 |
+
np.argmax(p.predictions, axis=1), p.label_ids
|
125 |
+
)
|
126 |
+
)
|
127 |
+
|
128 |
+
#Lancement de l'entraînement
|
129 |
+
trainer.train()
|
130 |
+
|
131 |
+
#Sauvegarde finale du modèle
|
132 |
+
trainer.save_model("../models/fine_tuned_model")
|
133 |
+
print("✅ Modèle sauvegardé dans ../models/fine_tuned_model")
|
134 |
+
|
135 |
+
|
136 |
+
if __name__ == "__main__":
|
137 |
+
main()
|
scripts/utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
2 |
+
|
3 |
+
def compute_metrics(preds, labels):
|
4 |
+
"""
|
5 |
+
Calcule les métriques de classification à partir des prédictions du modèle
|
6 |
+
et des labels de vérité terrain (vrais).
|
7 |
+
|
8 |
+
Args:
|
9 |
+
preds (array-like): Les classes prédites par le modèle (entiers).
|
10 |
+
labels (array-like): Les vraies classes associées aux exemples (entiers).
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
dict: Dictionnaire contenant les métriques suivantes :
|
14 |
+
- "accuracy" : exactitude globale des prédictions
|
15 |
+
- "f1" : score F1 pondéré (par classe)
|
16 |
+
- "precision" : précision pondérée
|
17 |
+
- "recall" : rappel pondéré
|
18 |
+
"""
|
19 |
+
|
20 |
+
#Calcule précision, rappel et F1 pondérés selon la taille de chaque classe
|
21 |
+
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
|
22 |
+
|
23 |
+
#Calcule l'accuracy brute
|
24 |
+
acc = accuracy_score(labels, preds)
|
25 |
+
|
26 |
+
return {
|
27 |
+
"accuracy": acc,
|
28 |
+
"f1": f1,
|
29 |
+
"precision": precision,
|
30 |
+
"recall": recall
|
31 |
+
}
|