Woziii committed on
Commit 9386df0
1 Parent(s): 575de15

Update app.py

Files changed (1)
  1. app.py +34 -62
app.py CHANGED
@@ -11,55 +11,37 @@ import time
 # Authentication
 login(token=os.environ["HF_TOKEN"])
 
-# Hierarchical model structure
-models_hierarchy = {
-    "meta-llama": {
-        "Llama-2": ["7B", "13B", "70B"],
-        "Llama-3": ["8B", "3.2B", "3.1B"]
-    },
-    "mistralai": {
-        "Mistral": ["7B-v0.1", "7B-v0.3"],
-        "Mixtral": ["8x7B-v0.1"]
-    },
-    "google": {
-        "Gemma": ["2B", "9B", "27B"]
-    },
-    "croissantllm": {
-        "CroissantLLM": ["Base"]
-    }
-}
-
-# Languages supported per model
-models_languages = {
-    "meta-llama/Llama-2-7B": ["en"],
-    "meta-llama/Llama-2-13B": ["en"],
-    "meta-llama/Llama-2-70B": ["en"],
-    "meta-llama/Llama-3-8B": ["en"],
-    "meta-llama/Llama-3-3.2B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
-    "meta-llama/Llama-3-3.1B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
+# List of models and their supported languages
+models_and_languages = {
+    "meta-llama/Llama-2-13b-hf": ["en"],
+    "meta-llama/Llama-2-7b-hf": ["en"],
+    "meta-llama/Llama-2-70b-hf": ["en"],
+    "meta-llama/Meta-Llama-3-8B": ["en"],
+    "meta-llama/Llama-3.2-3B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
+    "meta-llama/Llama-3.1-8B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
     "mistralai/Mistral-7B-v0.1": ["en"],
     "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
     "mistralai/Mistral-7B-v0.3": ["en"],
-    "google/Gemma-2B": ["en"],
-    "google/Gemma-9B": ["en"],
-    "google/Gemma-27B": ["en"],
+    "google/gemma-2-2b": ["en"],
+    "google/gemma-2-9b": ["en"],
+    "google/gemma-2-27b": ["en"],
     "croissantllm/CroissantLLMBase": ["en", "fr"]
 }
 
 # Recommended parameters for each model
 model_parameters = {
-    "meta-llama/Llama-2-7B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Llama-2-13B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Llama-2-70B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
-    "meta-llama/Llama-3-3.2B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
-    "meta-llama/Llama-3-3.1B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
+    "meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-2-70b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Meta-Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
+    "meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
+    "meta-llama/Llama-3.1-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
     "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
     "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
     "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
-    "google/Gemma-2B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-    "google/Gemma-9B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-    "google/Gemma-27B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/gemma-2-9b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/gemma-2-27b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
     "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
 }
 
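Note: load_model (below) indexes both of these dictionaries with the selected model ID, so models_and_languages and model_parameters must stay key-for-key in sync or the lookup raises KeyError. A minimal sketch of a startup guard, assuming one were added just after the two dicts (the assert and its message are illustrative, not part of this commit):

    # Hypothetical sanity check: fail fast if the two model tables drift apart.
    assert set(models_and_languages) == set(model_parameters), \
        "models_and_languages and model_parameters must list the same model IDs"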
@@ -68,31 +50,24 @@ model = None
 tokenizer = None
 selected_language = None
 
-def update_model_choices(company):
-    return list(models_hierarchy[company].keys())
-
-def update_variation_choices(company, model_name):
-    return models_hierarchy[company][model_name]
-
-def load_model(company, model_name, variation, progress=gr.Progress()):
+def load_model(model_name, progress=gr.Progress()):
     global model, tokenizer
-    full_model_name = f"{company}/{model_name}-{variation}"
-
     try:
         progress(0, desc="Chargement du tokenizer")
-        tokenizer = AutoTokenizer.from_pretrained(full_model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
         progress(0.5, desc="Chargement du modèle")
 
-        if "mixtral" in full_model_name.lower():
+        # Model-specific configurations
+        if "mixtral" in model_name.lower():
             model = AutoModelForCausalLM.from_pretrained(
-                full_model_name,
+                model_name,
                 torch_dtype=torch.float16,
                 device_map="auto",
                 load_in_8bit=True
             )
         else:
             model = AutoModelForCausalLM.from_pretrained(
-                full_model_name,
+                model_name,
                 torch_dtype=torch.float16,
                 device_map="auto"
             )
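Note: the Mixtral branch needs the bitsandbytes package and a CUDA device for load_in_8bit=True, and later transformers releases deprecate passing load_in_8bit directly to from_pretrained in favor of an explicit quantization config. A sketch of the equivalent call under that newer API (an assumption about the installed version, not what this commit uses):

    from transformers import AutoModelForCausalLM, BitsAndBytesConfig
    import torch

    # Same 8-bit Mixtral load, expressed through BitsAndBytesConfig.
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mixtral-8x7B-v0.1",
        torch_dtype=torch.float16,
        device_map="auto",
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    )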
@@ -101,11 +76,12 @@ def load_model(company, model_name, variation, progress=gr.Progress()):
         tokenizer.pad_token = tokenizer.eos_token
 
         progress(1.0, desc="Modèle chargé")
-        available_languages = models_languages[full_model_name]
+        available_languages = models_and_languages[model_name]
 
-        params = model_parameters[full_model_name]
+        # Update the sliders with the recommended values
+        params = model_parameters[model_name]
         return (
-            f"Modèle {full_model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
+            f"Modèle {model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
             gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
             params["temperature"],
             params["top_p"],
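Note: returning a freshly constructed gr.Dropdown(...) from a handler to update an existing component is the Gradio 4.x idiom; on Gradio 3.x the same update is spelled gr.update(...). A sketch of the 3.x-compatible form of the dropdown element above, assuming that older API:

    # Gradio 3.x-style equivalent of the returned dropdown update.
    gr.update(choices=available_languages, value=available_languages[0],
              visible=True, interactive=True)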
@@ -120,6 +96,7 @@ def set_language(lang):
     return f"Langue sélectionnée : {lang}"
 
 def ensure_token_display(token):
+    """Ensure the token is displayed correctly."""
     if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
         return tokenizer.decode([int(token)])
     return token
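Note: the added docstring spells out the helper's contract: a token that arrives as a numeric ID string is decoded back to text, anything else passes through unchanged. Illustrative calls, assuming a tokenizer is already loaded:

    ensure_token_display("1234")   # numeric string: decoded via tokenizer.decode([1234])
    ensure_token_display("chat")   # ordinary token text: returned unchanged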
@@ -229,9 +206,7 @@ with gr.Blocks() as demo:
     gr.Markdown("# LLM&BIAS")
 
     with gr.Accordion("Sélection du modèle"):
-        company_dropdown = gr.Dropdown(choices=list(models_hierarchy.keys()), label="Choisissez une société")
-        model_dropdown = gr.Dropdown(label="Choisissez un modèle", interactive=False)
-        variation_dropdown = gr.Dropdown(label="Choisissez une variation", interactive=False)
+        model_dropdown = gr.Dropdown(choices=list(models_and_languages.keys()), label="Choisissez un modèle")
         load_button = gr.Button("Charger le modèle")
         load_output = gr.Textbox(label="Statut du chargement")
         language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
@@ -256,11 +231,8 @@ with gr.Blocks() as demo:
 
     reset_button = gr.Button("Réinitialiser")
 
-    company_dropdown.change(update_model_choices, inputs=[company_dropdown], outputs=[model_dropdown])
-    model_dropdown.change(update_variation_choices, inputs=[company_dropdown, model_dropdown], outputs=[variation_dropdown])
-
     load_button.click(load_model,
-                      inputs=[company_dropdown, model_dropdown, variation_dropdown],
+                      inputs=[model_dropdown],
                       outputs=[load_output, language_dropdown, temperature, top_p, top_k])
     language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
     analyze_button.click(analyze_next_token,
@@ -273,4 +245,4 @@ with gr.Blocks() as demo:
                       outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text, language_dropdown, language_output])
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()