Update app.py
app.py CHANGED
@@ -62,28 +62,21 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
 
     try:
         with torch.no_grad():
-            outputs = model(**inputs
+            outputs = model(**inputs)
 
         last_token_logits = outputs.logits[0, -1, :]
         probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
-
-        # Get the 10 most probable tokens
         top_k = 10
         top_probs, top_indices = torch.topk(probabilities, top_k)
         top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
         prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
+        prob_plot = plot_probabilities(prob_data)
 
-
-        prob_text = "Most likely next tokens:\n\n"
-        for word, prob in prob_data.items():
-            escaped_word = word.replace("<", "&lt;").replace(">", "&gt;")
-            prob_text += f"{escaped_word}: {prob:.2%}\n"
+        prob_text = "\n".join([f"{word}: {prob:.2%}" for word, prob in prob_data.items()])
 
-
-        prob_plot = plot_probabilities(prob_data)
-        attention_plot = plot_attention(inputs["input_ids"][0], outputs.attentions)
+        attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)
 
-        return prob_text,
+        return prob_text, attention_heatmap, prob_plot
     except Exception as e:
         return f"Error during analysis: {str(e)}", None, None
 
@@ -115,32 +108,40 @@ def plot_probabilities(prob_data):
     probs = list(prob_data.values())
 
     fig, ax = plt.subplots(figsize=(12, 6))
-    bars = ax.bar(
-    ax.set_title("Probabilities of the most likely next tokens")
-    ax.set_xlabel("Tokens")
-    ax.set_ylabel("Probability")
-
-
-
-
-    for
+    bars = ax.bar(words, probs, color='skyblue')
+    ax.set_title("Probabilities of the 10 most likely next tokens", fontsize=16)
+    ax.set_xlabel("Tokens", fontsize=12)
+    ax.set_ylabel("Probability", fontsize=12)
+    plt.xticks(rotation=45, ha='right', fontsize=10)
+    plt.yticks(fontsize=10)
+
+    # Add the percentages above the bars
+    for bar in bars:
         height = bar.get_height()
-    ax.text(
-
+        ax.text(bar.get_x() + bar.get_width()/2., height,
+                f'{height:.2%}',
+                ha='center', va='bottom', fontsize=10)
 
     plt.tight_layout()
     return fig
 
-def
+def plot_attention_alternative(input_ids, last_token_logits):
     input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
+    attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
+    top_k = min(len(input_tokens), 10)  # Limit to 10 tokens for readability
+    top_attention_scores, _ = torch.topk(attention_scores, top_k)
 
-
-
+    fig, ax = plt.subplots(figsize=(14, 7))
+    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
+    ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
+    ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
+    ax.set_title("Attention scores for the last tokens", fontsize=16)
 
-
-
+    # Adjust the colorbar
+    cbar = ax.collections[0].colorbar
+    cbar.set_label("Attention score", fontsize=12)
+    cbar.ax.tick_params(labelsize=10)
 
-    ax.set_title("Average attention map")
     plt.tight_layout()
     return fig
 
@@ -151,7 +152,7 @@ def reset():
     return "", 1.0, 1.0, 50, None, None, None, None
 
 with gr.Blocks() as demo:
-    gr.Markdown("#
+    gr.Markdown("# Text analysis and generation")
 
     with gr.Accordion("Model selection"):
         model_dropdown = gr.Dropdown(choices=models, label="Choose a model")
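The updated analyze_next_token boils down to a plain next-token probability probe: run the model once, softmax the logits at the final position, and keep the top 10 entries. A minimal standalone sketch of that computation, assuming any causal LM from transformers; the gpt2 checkpoint and the prompt are placeholders here, not what this Space necessarily loads:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; the Space picks its model from a dropdown.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# The logits at the final position define the next-token distribution.
last_token_logits = outputs.logits[0, -1, :]
probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
top_probs, top_indices = torch.topk(probabilities, 10)
top_words = [tokenizer.decode([idx.item()]).strip() for idx in top_indices]
print({word: f"{prob.item():.2%}" for word, prob in zip(top_words, top_probs)})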
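One caveat on the new plot_attention_alternative: despite the name, it does not plot attention. It softmaxes the last token's logits over the whole vocabulary, so the heatmap values are the top 10 next-token probabilities, and the input_tokens[-top_k:] tick labels bear no relation to them. Genuine attention weights come from the forward pass itself via output_attentions=True. A minimal sketch, reusing model and inputs from the snippet above and assuming the checkpoint returns attentions:

import torch

with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

# outputs.attentions holds one (batch, heads, seq, seq) tensor per layer.
# Stack to (layers, batch, heads, seq, seq), then average layers and heads.
attn = torch.stack(outputs.attentions).mean(dim=(0, 2))[0]  # (seq, seq)

# The last row says how strongly the final position attends to each
# input token, so it lines up one-to-one with the input token labels.
last_row = attn[-1]
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

Plotted against tokens, last_row gives the token-to-score pairing that the heatmap's x-axis labels presuppose.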