aurelien committed
Commit 845c5fd · 1 Parent(s): f63cd93

Edit script for GPU

Files changed (1)
  1. app.py +78 -4
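
This commit moves every model onto one shared device string, preferring cuda, then mps, then cpu, instead of the previous CPU-only pipelines (device=-1). As a quick sanity check after startup, one could confirm which device the models actually landed on; the sketch below is not part of the commit and only uses standard PyTorch / transformers attributes on the globals created in app.py (see the diff that follows):

# Hedged sanity-check sketch, not part of the commit.
# text_model, sentiment_analyzer and toxicity_analyzer are the globals defined in app.py below.
print(next(text_model.parameters()).device)  # SentenceTransformer is a torch.nn.Module
print(sentiment_analyzer.device)             # transformers pipelines expose their torch device
print(toxicity_analyzer.device)
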
app.py CHANGED
@@ -7,7 +7,6 @@ import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import torch
- from validate_comment_sentiment_tags import analyze_comment  # your code above, you can also copy it here

app = FastAPI(title="Comment Validator API")

@@ -15,14 +14,89 @@ app = FastAPI(title="Comment Validator API")
# 🔹 Loading the models
# =====================================

- device = "mps" if torch.backends.mps.is_available() else "cpu"
+ if torch.cuda.is_available():
+     device = "cuda"
+ elif torch.backends.mps.is_available():
+     device = "mps"  # for your local Mac
+ else:
+     device = "cpu"
print(f"🧠 Using device: {device}")

+ print("Loading model embedding")
text_model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2", device=device)
+ print("Loading model classifier")
clf = joblib.load("models/classifier.joblib")
+ print("Loading model encoder")
encoder = joblib.load("models/encoder.joblib")
- sentiment_analyzer = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment", device=-1)
- toxicity_analyzer = pipeline("text-classification", model="unitary/toxic-bert", return_all_scores=True, device=-1)
+ print("Loading model sentiment-analysis")
+ sentiment_analyzer = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment", device=device)
+ print("Loading model toxicity")
+ toxicity_analyzer = pipeline("text-classification", model="unitary/toxic-bert", return_all_scores=True, device=device)
+
+ def analyze_comment(comment: str, category: str, country: str) -> dict:
+     reasons = []
+
+     # --- Sentiment analysis ---
+     try:
+         sentiment = sentiment_analyzer(comment[:512])[0]
+         label = sentiment["label"]
+         score = sentiment["score"]
+     except Exception:
+         label, score = "unknown", 0.0
+
+     if "1" in label or "2" in label:
+         sentiment_score = -1
+         reasons.append("Le ton semble négatif ou insatisfait.")
+     elif "4" in label or "5" in label:
+         sentiment_score = 1
+     else:
+         sentiment_score = 0
+
+     # --- Text encoding ---
+     X_text = text_model.encode([comment])
+
+     # --- Category/country encoding ---
+     df_cat = pd.DataFrame([[category, country]], columns=["category", "country"])
+     try:
+         X_cat = encoder.transform(df_cat)
+     except ValueError:
+         reasons.append(f"Catégorie ou pays inconnus : {category}, {country}")
+         n_features = sum(len(cats) for cats in encoder.categories_)
+         X_cat = np.zeros((1, n_features))
+
+     # --- Concatenation ---
+     X = np.concatenate([X_text, X_cat], axis=1)
+
+     # --- Validity prediction ---
+     proba = clf.predict_proba(X)[0][1]
+     prediction = proba >= 0.5
+
+     if len(comment.split()) < 3:
+         reasons.append("Le commentaire est trop court.")
+     if sentiment_score < 0:
+         reasons.append("Le ton global est négatif.")
+     if proba < 0.4:
+         reasons.append("Le modèle estime une faible probabilité de validité.")
+
+     # --- Toxicity analysis ---
+     try:
+         tox_scores = toxicity_analyzer(comment[:512])[0]  # truncate for safety
+         tags = {f"tag_{item['label']}": round(item['score'], 3) for item in tox_scores}
+     except Exception:
+         tags = {f"tag_{label}": 0.0 for label in ["toxicity", "severe_toxicity", "obscene", "identity_attack", "insult", "threat"]}
+
+     # --- Final result ---
+     result = {
+         "is_valid": bool(prediction),
+         "confidence": round(float(proba), 3),
+         "sentiment": label,
+         "sentiment_score": round(float(score), 3),
+         "reasons": "; ".join(reasons) if reasons else "Aucune anomalie détectée."
+     }
+
+     result.update(tags)
+     return result
+

# =====================================
# 🔸 Request/response models
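
For reference, a minimal smoke test of the new analyze_comment helper might look like the sketch below. It is not part of the commit; the comment, category and country values are purely illustrative and assume the serialized models under models/ load successfully and that the encoder was fitted on comparable category/country labels.

# Illustrative smoke test, not part of the commit.
if __name__ == "__main__":
    demo = analyze_comment(
        comment="Produit conforme, livraison rapide, je recommande.",  # sample comment (illustrative)
        category="electronics",                                        # hypothetical category label
        country="FR",                                                  # hypothetical country label
    )
    print(demo["is_valid"], demo["confidence"], demo["sentiment"], demo["reasons"])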