Woziii committed on
Commit
3c28324
1 Parent(s): 63afc3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -26
app.py CHANGED
@@ -4,7 +4,9 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
4
  from huggingface_hub import login
5
  import os
6
  import matplotlib.pyplot as plt
 
7
  import numpy as np
 
8
 
9
  # Authentification
10
  login(token=os.environ["HF_TOKEN"])
@@ -30,18 +32,22 @@ models = [
30
  model = None
31
  tokenizer = None
32
 
33
- def load_model(model_name):
34
  global model, tokenizer
35
  try:
36
- tokenizer = AutoTokenizer.from_pretrained(model_name)
37
- model = AutoModelForCausalLM.from_pretrained(
38
- model_name,
39
- torch_dtype=torch.bfloat16,
40
- device_map="auto",
41
- attn_implementation="eager"
42
- )
43
- if tokenizer.pad_token is None:
44
- tokenizer.pad_token = tokenizer.eos_token
 
 
 
 
45
  return f"Modèle {model_name} chargé avec succès."
46
  except Exception as e:
47
  return f"Erreur lors du chargement du modèle : {str(e)}"
@@ -68,12 +74,10 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
68
 
69
  prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
70
 
71
- # Simplification de l'affichage de l'attention
72
- attention_text = "Attention non disponible pour ce modèle"
73
- if hasattr(outputs, 'attentions') and outputs.attentions is not None:
74
- attention_text = "Attention disponible"
75
 
76
- return prob_text, attention_text, prob_plot
77
  except Exception as e:
78
  return f"Erreur lors de l'analyse : {str(e)}", None, None
79
 
@@ -89,16 +93,14 @@ def generate_text(input_text, temperature, top_p, top_k):
89
  with torch.no_grad():
90
  outputs = model.generate(
91
  **inputs,
92
- max_new_tokens=1, # Génère seulement le prochain mot
93
  temperature=temperature,
94
  top_p=top_p,
95
  top_k=top_k
96
  )
97
 
98
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
99
- # Ne retourne que le nouveau mot généré
100
- new_word = generated_text[len(input_text):].strip()
101
- return new_word
102
  except Exception as e:
103
  return f"Erreur lors de la génération : {str(e)}"
104
 
@@ -107,7 +109,7 @@ def plot_probabilities(prob_data):
107
  probs = list(prob_data.values())
108
 
109
  fig, ax = plt.subplots(figsize=(10, 5))
110
- ax.bar(words, probs)
111
  ax.set_title("Probabilités des tokens suivants les plus probables")
112
  ax.set_xlabel("Tokens")
113
  ax.set_ylabel("Probabilité")
@@ -115,6 +117,20 @@ def plot_probabilities(prob_data):
115
  plt.tight_layout()
116
  return fig
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  def reset():
119
  global model, tokenizer
120
  model = None
@@ -138,24 +154,25 @@ with gr.Blocks() as demo:
138
  analyze_button = gr.Button("Analyser le prochain token")
139
 
140
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
141
- attention_info = gr.Textbox(label="Information sur l'attention")
142
 
143
- prob_plot = gr.Plot(label="Probabilités des tokens suivants")
 
 
144
 
145
  generate_button = gr.Button("Générer le prochain mot")
146
- generated_word = gr.Textbox(label="Mot généré")
147
 
148
  reset_button = gr.Button("Réinitialiser")
149
 
150
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
151
  analyze_button.click(analyze_next_token,
152
  inputs=[input_text, temperature, top_p, top_k],
153
- outputs=[next_token_probs, attention_info, prob_plot])
154
  generate_button.click(generate_text,
155
  inputs=[input_text, temperature, top_p, top_k],
156
- outputs=[generated_word])
157
  reset_button.click(reset,
158
- outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_info, prob_plot, generated_word])
159
 
160
  if __name__ == "__main__":
161
  demo.launch()
 
4
  from huggingface_hub import login
5
  import os
6
  import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
  import numpy as np
9
+ import time
10
 
11
  # Authentification
12
  login(token=os.environ["HF_TOKEN"])
 
32
  model = None
33
  tokenizer = None
34
 
35
def load_model(model_name, progress=gr.Progress()):
    """Load the tokenizer and causal LM for ``model_name`` into the module globals.

    Progress is reported at the real loading stages (tokenizer, then model)
    instead of through a simulated sleep loop, so the bar reflects actual work
    and no artificial delay is added.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        progress: Gradio progress tracker; called with a fraction and a
            description at each stage.

    Returns:
        A user-facing status string: a success message, or an error message
        containing the exception text if loading failed.
    """
    global model, tokenizer
    try:
        progress(0.0, desc="Chargement du tokenizer")
        new_tokenizer = AutoTokenizer.from_pretrained(model_name)

        progress(0.25, desc="Chargement du modèle")
        new_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            attn_implementation="eager",
        )

        # Fall back to the EOS token when the tokenizer defines no pad token.
        if new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token

        # Publish both objects together, only once both loads succeeded, so a
        # failure never leaves a new tokenizer paired with a stale model.
        tokenizer, model = new_tokenizer, new_model
        progress(1.0, desc="Chargement terminé")
        return f"Modèle {model_name} chargé avec succès."
    except Exception as e:
        # UI boundary: surface the failure as text in the status box.
        return f"Erreur lors du chargement du modèle : {str(e)}"
 
74
 
75
  prob_text = "\n".join([f"{word}: {prob:.4f}" for word, prob in prob_data.items()])
76
 
77
+ # Alternative pour le mécanisme d'attention
78
+ attention_heatmap = plot_attention_alternative(inputs["input_ids"][0], last_token_logits)
 
 
79
 
80
+ return prob_text, attention_heatmap, prob_plot
81
  except Exception as e:
82
  return f"Erreur lors de l'analyse : {str(e)}", None, None
83
 
 
93
  with torch.no_grad():
94
  outputs = model.generate(
95
  **inputs,
96
+ max_new_tokens=1,
97
  temperature=temperature,
98
  top_p=top_p,
99
  top_k=top_k
100
  )
101
 
102
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
103
+ return generated_text # Retourne l'input + le nouveau mot
 
 
104
  except Exception as e:
105
  return f"Erreur lors de la génération : {str(e)}"
106
 
 
109
  probs = list(prob_data.values())
110
 
111
  fig, ax = plt.subplots(figsize=(10, 5))
112
+ sns.barplot(x=words, y=probs, ax=ax)
113
  ax.set_title("Probabilités des tokens suivants les plus probables")
114
  ax.set_xlabel("Tokens")
115
  ax.set_ylabel("Probabilité")
 
117
  plt.tight_layout()
118
  return fig
119
 
120
def plot_attention_alternative(input_ids, last_token_logits):
    """Plot a one-row heatmap of the most probable next-token candidates.

    NOTE: this is not a view of the model's attention weights. The values are
    the softmax of the next-token logits, so each column is labelled with the
    vocabulary token its probability belongs to (previously the columns were
    labelled with the last prompt tokens, which are unrelated to the values).

    Args:
        input_ids: Token ids of the prompt. Kept for interface compatibility;
            the plotted values do not depend on it.
        last_token_logits: Logits for the next token; assumed to be a 1-D
            tensor over the vocabulary (the result is given one extra leading
            dimension to form the single heatmap row).

    Returns:
        The matplotlib figure containing the heatmap.
    """
    # Cast to float32 first: the model is loaded in bfloat16, which cannot be
    # converted to a NumPy array for plotting.
    candidate_probs = torch.nn.functional.softmax(last_token_logits.float(), dim=-1)
    top_k = min(candidate_probs.shape[-1], 10)  # Cap at 10 columns for readability
    top_probs, top_ids = torch.topk(candidate_probs, top_k)

    # Label each column with the token that actually owns the probability.
    candidate_tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
    heat_values = top_probs.detach().cpu().unsqueeze(0).numpy()

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.heatmap(heat_values, annot=True, cmap="YlOrRd", cbar=False, ax=ax)
    ax.set_xticklabels(candidate_tokens, rotation=45, ha="right")
    ax.set_yticklabels(["Probabilité"], rotation=0)
    ax.set_title("Probabilités des tokens candidats les plus probables")
    plt.tight_layout()
    return fig
133
+
134
  def reset():
135
  global model, tokenizer
136
  model = None
 
154
  analyze_button = gr.Button("Analyser le prochain token")
155
 
156
  next_token_probs = gr.Textbox(label="Probabilités du prochain token")
 
157
 
158
+ with gr.Row():
159
+ attention_plot = gr.Plot(label="Visualisation de l'attention")
160
+ prob_plot = gr.Plot(label="Probabilités des tokens suivants")
161
 
162
  generate_button = gr.Button("Générer le prochain mot")
163
+ generated_text = gr.Textbox(label="Texte généré")
164
 
165
  reset_button = gr.Button("Réinitialiser")
166
 
167
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
168
  analyze_button.click(analyze_next_token,
169
  inputs=[input_text, temperature, top_p, top_k],
170
+ outputs=[next_token_probs, attention_plot, prob_plot])
171
  generate_button.click(generate_text,
172
  inputs=[input_text, temperature, top_p, top_k],
173
+ outputs=[generated_text])
174
  reset_button.click(reset,
175
+ outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
176
 
177
  if __name__ == "__main__":
178
  demo.launch()