Woziii commited on
Commit
9787d82
1 Parent(s): 0db8079

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -14
app.py CHANGED
@@ -6,6 +6,7 @@ import matplotlib.pyplot as plt
6
  import numpy as np
7
  from huggingface_hub import login
8
  import os
 
9
  login(token=os.environ["HF_TOKEN"])
10
 
11
  # Liste des modèles
@@ -50,23 +51,52 @@ def generate_text(input_text, temperature, top_p, top_k):
50
  # Obtenir les logits pour le dernier token généré
51
  last_token_logits = model(outputs.sequences[:, -1:]).logits[:, -1, :]
52
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Extraire les attentions
54
  attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
55
 
56
- # Visualiser l'attention
57
- plt.figure(figsize=(10, 10))
58
- plt.imshow(attentions, cmap='viridis')
59
- plt.title("Carte d'attention")
60
- attention_plot = plt.gcf()
61
- plt.close()
62
 
63
- # Obtenir les mots les plus probables
64
- probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
65
- top_probs, top_indices = torch.topk(probs[0], k=5)
66
- top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
 
67
 
68
- return generated_text, attention_plot, top_words
 
 
 
 
 
 
69
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def reset():
72
  return "", 1.0, 1.0, 50, None, None, None
@@ -91,15 +121,18 @@ with gr.Blocks() as demo:
91
 
92
  with gr.Row():
93
  attention_plot = gr.Plot(label="Visualisation de l'attention")
94
- top_words = gr.JSON(label="Mots les plus probables")
95
 
96
  reset_button = gr.Button("Réinitialiser")
97
 
98
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
99
  generate_button.click(generate_text,
100
  inputs=[input_text, temperature, top_p, top_k],
101
- outputs=[output_text, attention_plot, top_words])
102
  reset_button.click(reset,
103
- outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, top_words])
 
 
 
104
 
105
  demo.launch()
 
6
  import numpy as np
7
  from huggingface_hub import login
8
  import os
9
+
10
  login(token=os.environ["HF_TOKEN"])
11
 
12
  # Liste des modèles
 
51
  # Obtenir les logits pour le dernier token généré
52
  last_token_logits = model(outputs.sequences[:, -1:]).logits[:, -1, :]
53
 
54
+ # Appliquer softmax pour obtenir les probabilités
55
+ probabilities = torch.nn.functional.softmax(last_token_logits[0], dim=-1)
56
+
57
+ # Obtenir les top 5 tokens les plus probables
58
+ top_k = 5
59
+ top_probs, top_indices = torch.topk(probabilities, top_k)
60
+ top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
61
+
62
+ # Préparer les données pour le graphique des probabilités
63
+ prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
64
+
65
  # Extraire les attentions
66
  attentions = outputs.attentions[-1][-1].mean(dim=0).numpy()
67
 
68
+ # Préparer les données pour la carte d'attention
69
+ tokens = tokenizer.convert_ids_to_tokens(outputs.sequences[0])
70
+ attention_data = {
71
+ 'attention': attentions.tolist(),
72
+ 'tokens': tokens
73
+ }
74
 
75
+ return generated_text, attention_data, prob_data
76
+
77
+ def plot_attention(attention_data):
78
+ attention = np.array(attention_data['attention'])
79
+ tokens = attention_data['tokens']
80
 
81
+ plt.figure(figsize=(10, 10))
82
+ plt.imshow(attention, cmap='viridis')
83
+ plt.colorbar()
84
+ plt.xticks(range(len(tokens)), tokens, rotation=90)
85
+ plt.yticks(range(len(tokens)), tokens)
86
+ plt.title("Carte d'attention")
87
+ return plt
88
 
89
+ def plot_probabilities(prob_data):
90
+ words = list(prob_data.keys())
91
+ probs = list(prob_data.values())
92
+
93
+ plt.figure(figsize=(10, 5))
94
+ plt.bar(words, probs)
95
+ plt.title("Probabilités des tokens suivants les plus probables")
96
+ plt.xlabel("Tokens")
97
+ plt.ylabel("Probabilité")
98
+ plt.xticks(rotation=45)
99
+ return plt
100
 
101
  def reset():
102
  return "", 1.0, 1.0, 50, None, None, None
 
121
 
122
  with gr.Row():
123
  attention_plot = gr.Plot(label="Visualisation de l'attention")
124
+ prob_plot = gr.Plot(label="Probabilités des tokens suivants")
125
 
126
  reset_button = gr.Button("Réinitialiser")
127
 
128
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
129
  generate_button.click(generate_text,
130
  inputs=[input_text, temperature, top_p, top_k],
131
+ outputs=[output_text, attention_plot, prob_plot])
132
  reset_button.click(reset,
133
+ outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
134
+
135
+ attention_plot.change(plot_attention, inputs=[attention_plot], outputs=[attention_plot])
136
+ prob_plot.change(plot_probabilities, inputs=[prob_plot], outputs=[prob_plot])
137
 
138
  demo.launch()