Mattimax committed
Commit e05153e · verified · 1 Parent(s): 8398ebb

Update app.py

Files changed (1)
  1. app.py +446 -101
app.py CHANGED
@@ -7,11 +7,12 @@ import pandas as pd
 
 
 # =========================
-# Configurazione benchmark
+# Configurazione generale
 # =========================
 
 MAX_MODELS = 5
-DEFAULT_NUM_SAMPLES = 50 # numero di esempi da usare per il benchmark
+MAX_DATASETS = 5
+DEFAULT_NUM_SAMPLES = 50 # numero di esempi da usare per ogni dataset
 
 
 def get_device():
@@ -20,21 +21,87 @@ def get_device():
     return "cpu"
 
 
-def load_boolq_dataset(num_samples=DEFAULT_NUM_SAMPLES):
-    """
-    Carica un subset del dataset BoolQ.
-    BoolQ: domande sì/no con un breve contesto.
-    """
+# =========================
+# Definizione dataset
+# =========================
+
+DATASETS = {
+    "boolq_en": {
+        "label": "BoolQ (en)",
+        "language": "en",
+        "description": "Yes/No QA su contesti in inglese",
+    },
+    "squad_it": {
+        "label": "SQuAD-it (it)",
+        "language": "it",
+        "description": "QA estrattivo in italiano",
+    },
+    "pawsx_it": {
+        "label": "PAWS-X (it)",
+        "language": "it",
+        "description": "Parafrasi in italiano (stesso significato sì/no)",
+    },
+    "sentiment_it": {
+        "label": "Sentiment-it (it)",
+        "language": "it",
+        "description": "Sentiment positivo/negativo in italiano",
+    },
+}
+
+DATASET_LABELS = [cfg["label"] for cfg in DATASETS.values()]
+
+LABEL_TO_KEY = {cfg["label"]: key for key, cfg in DATASETS.items()}
+
+
+# =========================
+# Loader dataset
+# =========================
+
+def load_boolq(num_samples=DEFAULT_NUM_SAMPLES):
     ds = load_dataset("boolq", split="validation")
     if num_samples is not None and num_samples < len(ds):
         ds = ds.select(range(num_samples))
     return ds
 
 
-def build_boolq_prompt(passage, question):
-    """
-    Prompt in italiano: il modello deve rispondere solo 'sì' o 'no'.
-    """
+def load_squad_it(num_samples=DEFAULT_NUM_SAMPLES):
+    # Nota: se "squad_it" non esiste o ha split diversi, qui puoi adattare.
+    ds = load_dataset("squad_it", split="test")
+    if num_samples is not None and num_samples < len(ds):
+        ds = ds.select(range(num_samples))
+    return ds
+
+
+def load_pawsx_it(num_samples=DEFAULT_NUM_SAMPLES):
+    ds = load_dataset("paws-x", "it", split="validation")
+    if num_samples is not None and num_samples < len(ds):
+        ds = ds.select(range(num_samples))
+    return ds
+
+
+def load_sentiment_it(num_samples=DEFAULT_NUM_SAMPLES):
+    ds = load_dataset("sentiment-it", split="train")
+    if num_samples is not None and num_samples < len(ds):
+        ds = ds.select(range(num_samples))
+    return ds
+
+
+# =========================
+# Prompt & parsing
+# =========================
+
+def build_boolq_prompt_en(passage, question):
+    prompt = (
+        "You are a question answering system. "
+        "Answer strictly with 'yes' or 'no'.\n\n"
+        f"Passage: {passage}\n"
+        f"Question: {question}\n"
+        "Answer:"
+    )
+    return prompt
+
+
+def build_boolq_prompt_it(passage, question):
     prompt = (
         "Sei un sistema di question answering. "
         "Rispondi strettamente solo con 'sì' o 'no'.\n\n"
@@ -45,6 +112,41 @@ def build_boolq_prompt(passage, question):
     return prompt
 
 
+def build_squad_it_prompt(context, question):
+    prompt = (
+        "Sei un sistema di question answering in italiano. "
+        "Rispondi con una breve frase che risponde alla domanda.\n\n"
+        f"Contesto: {context}\n"
+        f"Domanda: {question}\n"
+        "Risposta:"
+    )
+    return prompt
+
+
+def build_pawsx_it_prompt(sentence1, sentence2):
+    prompt = (
+        "Sei un sistema di riconoscimento di parafrasi in italiano.\n"
+        "Ti verranno date due frasi. Devi dire se esprimono lo stesso significato.\n"
+        "Rispondi strettamente solo con 'sì' o 'no'.\n\n"
+        f"Frase 1: {sentence1}\n"
+        f"Frase 2: {sentence2}\n"
+        "Le due frasi hanno lo stesso significato?\n"
+        "Risposta:"
+    )
+    return prompt
+
+
+def build_sentiment_it_prompt(text):
+    prompt = (
+        "Sei un sistema di analisi del sentiment in italiano.\n"
+        "Ti verrà dato un testo. Devi dire se il sentiment è positivo o negativo.\n"
+        "Rispondi strettamente solo con 'positivo' o 'negativo'.\n\n"
+        f"Testo: {text}\n"
+        "Sentiment:"
+    )
+    return prompt
+
+
 def parse_yes_no(output_text):
     """
     Estrae 'sì/si' o 'no' dall'output del modello.
@@ -72,28 +174,63 @@ def parse_yes_no(output_text):
     return None
 
 
-def evaluate_model_on_boolq(model_name, num_samples=DEFAULT_NUM_SAMPLES, max_new_tokens=5):
+def parse_sentiment_it(output_text):
     """
-    Esegue il benchmark di un modello su BoolQ.
-    Ritorna:
-    - accuracy
-    - numero di esempi valutati
-    - tempo medio per esempio
+    Ritorna True per positivo, False per negativo, None se non riconosciuto.
     """
-    device = get_device()
-    start_total = time.time()
+    text = output_text.strip().lower()
+    if not text:
+        return None
+
+    first = text.split()[0]
+
+    if first.startswith("pos"):
+        return True
+    if first.startswith("neg"):
+        return False
+
+    return None
+
 
-    # Caricamento modello e tokenizer
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-        model = AutoModelForCausalLM.from_pretrained(model_name)
-    except Exception as e:
-        raise RuntimeError(f"Errore nel caricamento del modello '{model_name}': {e}")
+def normalize_text(s):
+    return " ".join(s.strip().lower().split())
 
+
+# =========================
+# Modello: load & generate
+# =========================
+
+def load_model(model_name):
+    device = get_device()
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name)
     model.to(device)
     model.eval()
+    return tokenizer, model, device
 
-    ds = load_boolq_dataset(num_samples=num_samples)
+
+def generate_text(tokenizer, model, device, prompt, max_new_tokens=32):
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    with torch.no_grad():
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=False,
+            temperature=0.0,
+        )
+    gen_text = tokenizer.decode(
+        output_ids[0][inputs["input_ids"].shape[-1]:],
+        skip_special_tokens=True,
+    )
+    return gen_text
+
+
+# =========================
+# Valutazione per dataset
+# =========================
+
+def evaluate_on_boolq(model_name, tokenizer, model, device, num_samples=DEFAULT_NUM_SAMPLES, lang="en"):
+    ds = load_boolq(num_samples=num_samples)
 
     correct = 0
     total = 0
@@ -104,73 +241,201 @@ def evaluate_model_on_boolq(model_name, num_samples=DEFAULT_NUM_SAMPLES, max_new
         question = example["question"]
         label = example["answer"] # True/False
 
-        prompt = build_boolq_prompt(passage, question)
-        inputs = tokenizer(prompt, return_tensors="pt").to(device)
+        if lang == "en":
+            prompt = build_boolq_prompt_en(passage, question)
+        else:
+            prompt = build_boolq_prompt_it(passage, question)
 
         t0 = time.time()
-        with torch.no_grad():
-            output_ids = model.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                do_sample=False,
-                temperature=0.0,
-            )
+        gen_text = generate_text(tokenizer, model, device, prompt, max_new_tokens=5)
         t1 = time.time()
-        gen_text = tokenizer.decode(
-            output_ids[0][inputs["input_ids"].shape[-1]:],
-            skip_special_tokens=True,
-        )
 
         pred = parse_yes_no(gen_text)
 
-        # Contiamo sempre l'esempio, anche se il modello non risponde in modo valido
         total += 1
         times.append(t1 - t0)
 
         if pred is not None and pred == label:
             correct += 1
 
-    if total == 0:
-        accuracy = 0.0
-        avg_time = None
-    else:
-        accuracy = correct / total
-        avg_time = sum(times) / len(times) if times else None
+    accuracy = correct / total if total > 0 else 0.0
+    avg_time = sum(times) / len(times) if times else None
 
-    total_time = time.time() - start_total
+    return {
+        "model_name": model_name,
+        "dataset": "BoolQ (en)" if lang == "en" else "BoolQ (it)",
+        "num_samples": total,
+        "accuracy": accuracy,
+        "avg_time_per_sample_sec": avg_time,
+    }
+
+
+def evaluate_on_squad_it(model_name, tokenizer, model, device, num_samples=DEFAULT_NUM_SAMPLES):
+    ds = load_squad_it(num_samples=num_samples)
+
+    correct = 0
+    total = 0
+    times = []
+
+    for example in ds:
+        context = example["context"]
+        question = example["question"]
+        answers = example.get("answers", {})
+        gold_answers = answers.get("text", []) if isinstance(answers, dict) else []
+
+        prompt = build_squad_it_prompt(context, question)
+
+        t0 = time.time()
+        gen_text = generate_text(tokenizer, model, device, prompt, max_new_tokens=32)
+        t1 = time.time()
+
+        pred = normalize_text(gen_text)
+        total += 1
+        times.append(t1 - t0)
+
+        if gold_answers:
+            gold_norm = [normalize_text(a) for a in gold_answers]
+            if any(g in pred or pred in g for g in gold_norm):
+                correct += 1
+
+    accuracy = correct / total if total > 0 else 0.0
+    avg_time = sum(times) / len(times) if times else None
+
+    return {
+        "model_name": model_name,
+        "dataset": "SQuAD-it (it)",
+        "num_samples": total,
+        "accuracy": accuracy,
+        "avg_time_per_sample_sec": avg_time,
+    }
+
+
+def evaluate_on_pawsx_it(model_name, tokenizer, model, device, num_samples=DEFAULT_NUM_SAMPLES):
+    ds = load_pawsx_it(num_samples=num_samples)
+
+    correct = 0
+    total = 0
+    times = []
+
+    for example in ds:
+        s1 = example["sentence1"]
+        s2 = example["sentence2"]
+        label = example["label"] # 0: non-parafrasi, 1: parafrasi
+
+        prompt = build_pawsx_it_prompt(s1, s2)
+
+        t0 = time.time()
+        gen_text = generate_text(tokenizer, model, device, prompt, max_new_tokens=5)
+        t1 = time.time()
+
+        pred = parse_yes_no(gen_text)
+        total += 1
+        times.append(t1 - t0)
+
+        if pred is not None:
+            is_paraphrase = (label == 1)
+            if pred == is_paraphrase:
+                correct += 1
+
+    accuracy = correct / total if total > 0 else 0.0
+    avg_time = sum(times) / len(times) if times else None
+
+    return {
+        "model_name": model_name,
+        "dataset": "PAWS-X (it)",
+        "num_samples": total,
+        "accuracy": accuracy,
+        "avg_time_per_sample_sec": avg_time,
+    }
+
+
+def evaluate_on_sentiment_it(model_name, tokenizer, model, device, num_samples=DEFAULT_NUM_SAMPLES):
+    ds = load_sentiment_it(num_samples=num_samples)
+
+    correct = 0
+    total = 0
+    times = []
+
+    for example in ds:
+        text = example["text"]
+        label = example["label"] # 0: negativo, 1: positivo (tipico schema)
+
+        prompt = build_sentiment_it_prompt(text)
+
+        t0 = time.time()
+        gen_text = generate_text(tokenizer, model, device, prompt, max_new_tokens=5)
+        t1 = time.time()
+
+        pred = parse_sentiment_it(gen_text)
+        total += 1
+        times.append(t1 - t0)
+
+        if pred is not None:
+            is_positive = (label == 1)
+            if pred == is_positive:
+                correct += 1
+
+    accuracy = correct / total if total > 0 else 0.0
+    avg_time = sum(times) / len(times) if times else None
 
     return {
         "model_name": model_name,
+        "dataset": "Sentiment-it (it)",
         "num_samples": total,
         "accuracy": accuracy,
        "avg_time_per_sample_sec": avg_time,
-        "total_time_sec": total_time,
     }
 
 
+def evaluate_model_on_dataset(model_name, tokenizer, model, device, dataset_key, num_samples):
+    start_total = time.time()
+
+    if dataset_key == "boolq_en":
+        res = evaluate_on_boolq(model_name, tokenizer, model, device, num_samples=num_samples, lang="en")
+    elif dataset_key == "squad_it":
+        res = evaluate_on_squad_it(model_name, tokenizer, model, device, num_samples=num_samples)
+    elif dataset_key == "pawsx_it":
+        res = evaluate_on_pawsx_it(model_name, tokenizer, model, device, num_samples=num_samples)
+    elif dataset_key == "sentiment_it":
+        res = evaluate_on_sentiment_it(model_name, tokenizer, model, device, num_samples=num_samples)
+    else:
+        raise ValueError(f"Dataset non supportato: {dataset_key}")
+
+    total_time = time.time() - start_total
+    res["total_time_sec"] = total_time
+    return res
+
+
 # =========================
 # Funzioni per la UI
 # =========================
 
 def add_model_field(current_count):
-    """
-    Aumenta il numero di campi modello visibili, fino a MAX_MODELS.
-    """
     if current_count < MAX_MODELS:
         current_count += 1
     return current_count
 
 
 def get_visible_textboxes(model_count):
-    """
-    Ritorna la visibilità dei 5 campi modello in base a model_count.
-    """
     visibility = []
     for i in range(1, MAX_MODELS + 1):
         visibility.append(gr.update(visible=(i <= model_count)))
     return visibility
 
 
+def add_dataset_field(current_count):
+    if current_count < MAX_DATASETS:
+        current_count += 1
+    return current_count
+
+
+def get_visible_datasets(dataset_count):
+    visibility = []
+    for i in range(1, MAX_DATASETS + 1):
+        visibility.append(gr.update(visible=(i <= dataset_count)))
+    return visibility
+
+
 def run_benchmark_ui(
     model_1,
     model_2,
@@ -178,15 +443,15 @@ def run_benchmark_ui(
     model_4,
     model_5,
     model_count,
+    dataset_1,
+    dataset_2,
+    dataset_3,
+    dataset_4,
+    dataset_5,
+    dataset_count,
     num_samples,
 ):
-    """
-    Funzione chiamata dal pulsante 'Esegui benchmark'.
-    Raccoglie i nomi dei modelli, esegue il benchmark e ritorna:
-    - tabella risultati
-    - log testuale
-    """
-    # Raccogli i modelli attivi
+    # Raccogli modelli
     model_names = []
     all_models = [model_1, model_2, model_3, model_4, model_5]
     for i in range(model_count):
@@ -194,45 +459,64 @@ def run_benchmark_ui(
         if name:
             model_names.append(name)
 
+    # Raccogli dataset
+    dataset_labels = []
+    all_datasets = [dataset_1, dataset_2, dataset_3, dataset_4, dataset_5]
+    for i in range(dataset_count):
+        label = all_datasets[i]
+        if label in LABEL_TO_KEY:
+            dataset_labels.append(label)
+
     if len(model_names) < 2:
-        return (
-            pd.DataFrame(),
-            "Devi specificare almeno due modelli validi."
-        )
+        return pd.DataFrame(), "Devi specificare almeno due modelli validi."
+
+    if len(dataset_labels) < 1:
+        return pd.DataFrame(), "Devi selezionare almeno un dataset."
 
-    results = []
     logs = []
+    results = []
 
-    logs.append(f"Avvio benchmark su BoolQ con {num_samples} esempi...")
+    logs.append(f"Avvio benchmark con {num_samples} esempi per dataset...")
     logs.append(f"Modelli: {', '.join(model_names)}")
+    logs.append(f"Dataset: {', '.join(dataset_labels)}")
     logs.append("Device: " + get_device())
     logs.append("====================================")
 
-    for name in model_names:
-        logs.append(f"\n[MODELLO] {name}")
+    for model_name in model_names:
+        logs.append(f"\n[MODELLO] {model_name}")
        try:
-            res = evaluate_model_on_boolq(name, num_samples=num_samples)
-            results.append(res)
-
-            avg_time_str = (
-                f"{res['avg_time_per_sample_sec']:.3f}"
-                if res['avg_time_per_sample_sec'] is not None
-                else "N/A"
-            )
-
-            logs.append(
-                f" - Esempi valutati: {res['num_samples']}\n"
-                f" - Accuracy: {res['accuracy']:.3f}\n"
-                f" - Tempo medio per esempio (s): {avg_time_str}\n"
-                f" - Tempo totale (s): {res['total_time_sec']:.3f}"
-            )
+            tokenizer, model, device = load_model(model_name)
        except Exception as e:
-            logs.append(f" ERRORE: {e}")
+            logs.append(f" ERRORE nel caricamento del modello: {e}")
+            continue
+
+        for dlabel in dataset_labels:
+            dkey = LABEL_TO_KEY[dlabel]
+            logs.append(f" [DATASET] {dlabel}")
+            try:
+                res = evaluate_model_on_dataset(
+                    model_name, tokenizer, model, device, dkey, num_samples
+                )
+                results.append(res)
+
+                avg_time_str = (
+                    f"{res['avg_time_per_sample_sec']:.3f}"
+                    if res["avg_time_per_sample_sec"] is not None
+                    else "N/A"
+                )
+
+                logs.append(
+                    f" - Esempi valutati: {res['num_samples']}\n"
+                    f" - Accuracy: {res['accuracy']:.3f}\n"
+                    f" - Tempo medio per esempio (s): {avg_time_str}\n"
+                    f" - Tempo totale (s): {res['total_time_sec']:.3f}"
+                )
+            except Exception as e:
+                logs.append(f" ERRORE durante il benchmark: {e}")
 
     if results:
         df = pd.DataFrame(results)
-        # Ordina per accuracy decrescente
-        df = df.sort_values(by="accuracy", ascending=False)
+        df = df.sort_values(by=["dataset", "accuracy"], ascending=[True, False])
     else:
        df = pd.DataFrame()
 
@@ -241,27 +525,33 @@ def run_benchmark_ui(
 
 
 # =========================
-# Costruzione interfaccia Gradio
+# Interfaccia Gradio
 # =========================
 
-with gr.Blocks(title="LLM Benchmark Space - BoolQ (IT)") as demo:
+with gr.Blocks(title="LLM Benchmark Space - Multi-dataset") as demo:
     gr.Markdown(
         """
-        # 🔍 LLM Benchmark Space (BoolQ, IT)
+        # 🔍 LLM Benchmark Space (multi-dataset)
 
-        Inserisci i nomi dei modelli Hugging Face (es. `meta-llama/Meta-Llama-3-8B-Instruct`)
-        e confrontali su un subset del dataset **BoolQ** (domande sì/no).
+        Inserisci i nomi dei modelli Hugging Face (es. `Mattimax/DAC4.3`)
+        e confrontali su uno o più dataset selezionabili da menu a tendina.
 
         - Minimo **2 modelli**
        - Puoi aggiungere fino a **5 modelli** con il pulsante **"+ Aggiungi modello"**
-        - Output: tabella con **accuracy**, numero di esempi e tempi
-
-        I prompt sono in **italiano** e il modello deve rispondere solo con **"sì"** o **"no"**.
+        - Puoi selezionare **1 o più dataset** (fino a 5) con il pulsante **"+ Aggiungi dataset"**
+        - Output: tabella con **modello**, **dataset**, **accuracy**, numero di esempi e tempi
+
+        Dataset disponibili:
+        - BoolQ (en)
+        - SQuAD-it (it)
+        - PAWS-X (it)
+        - Sentiment-it (it)
        """
    )
 
    with gr.Row():
        with gr.Column():
+            # Stato numero modelli
            model_count_state = gr.State(value=2)
 
            model_1 = gr.Textbox(
@@ -295,14 +585,50 @@ with gr.Blocks(title="LLM Benchmark Space - BoolQ (IT)") as demo:
                visible=False,
            )
 
-            add_button = gr.Button("+ Aggiungi modello")
+            add_model_button = gr.Button("+ Aggiungi modello")
+
+            # Stato numero dataset
+            dataset_count_state = gr.State(value=1)
+
+            dataset_1 = gr.Dropdown(
+                label="Dataset 1",
+                choices=DATASET_LABELS,
+                value="BoolQ (en)",
+                visible=True,
+            )
+            dataset_2 = gr.Dropdown(
+                label="Dataset 2",
+                choices=DATASET_LABELS,
+                value="SQuAD-it (it)",
+                visible=False,
+            )
+            dataset_3 = gr.Dropdown(
+                label="Dataset 3",
+                choices=DATASET_LABELS,
+                value="PAWS-X (it)",
+                visible=False,
+            )
+            dataset_4 = gr.Dropdown(
+                label="Dataset 4",
+                choices=DATASET_LABELS,
+                value="Sentiment-it (it)",
+                visible=False,
+            )
+            dataset_5 = gr.Dropdown(
+                label="Dataset 5",
+                choices=DATASET_LABELS,
+                value="BoolQ (en)",
+                visible=False,
+            )
+
+            add_dataset_button = gr.Button("+ Aggiungi dataset")
 
            num_samples = gr.Slider(
                minimum=10,
                maximum=200,
                step=10,
                value=DEFAULT_NUM_SAMPLES,
-                label="Numero di esempi BoolQ da usare",
+                label="Numero di esempi per dataset",
            )
 
            run_button = gr.Button("🚀 Esegui benchmark", variant="primary")
@@ -311,6 +637,7 @@ with gr.Blocks(title="LLM Benchmark Space - BoolQ (IT)") as demo:
            results_df = gr.Dataframe(
                headers=[
                    "model_name",
+                    "dataset",
                    "num_samples",
                    "accuracy",
                    "avg_time_per_sample_sec",
@@ -321,23 +648,35 @@ with gr.Blocks(title="LLM Benchmark Space - BoolQ (IT)") as demo:
            )
            logs_box = gr.Textbox(
                label="Log esecuzione",
-                lines=20,
+                lines=25,
                interactive=False,
            )
 
-    # Logica pulsante "+ Aggiungi modello"
+    # Logica "+ Aggiungi modello"
    def on_add_model(model_count):
        new_count = add_model_field(model_count)
        visibility_updates = get_visible_textboxes(new_count)
        return [new_count] + visibility_updates
 
-    add_button.click(
+    add_model_button.click(
        fn=on_add_model,
        inputs=[model_count_state],
        outputs=[model_count_state, model_1, model_2, model_3, model_4, model_5],
    )
 
-    # Logica pulsante "Esegui benchmark"
+    # Logica "+ Aggiungi dataset"
+    def on_add_dataset(dataset_count):
+        new_count = add_dataset_field(dataset_count)
+        visibility_updates = get_visible_datasets(new_count)
+        return [new_count] + visibility_updates
+
+    add_dataset_button.click(
+        fn=on_add_dataset,
+        inputs=[dataset_count_state],
+        outputs=[dataset_count_state, dataset_1, dataset_2, dataset_3, dataset_4, dataset_5],
+    )
+
+    # Logica "Esegui benchmark"
    run_button.click(
        fn=run_benchmark_ui,
        inputs=[
@@ -347,6 +686,12 @@ with gr.Blocks(title="LLM Benchmark Space - BoolQ (IT)") as demo:
            model_4,
            model_5,
            model_count_state,
+            dataset_1,
+            dataset_2,
+            dataset_3,
+            dataset_4,
+            dataset_5,
+            dataset_count_state,
            num_samples,
        ],
        outputs=[results_df, logs_box],
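
A quick way to sanity-check the refactored evaluation path introduced by this commit, without launching the Gradio UI, is to call the new helpers directly. The sketch below is illustrative and not part of the commit: it assumes app.py can be imported without side effects (i.e. demo.launch() is guarded or commented out), and the model name sshleifer/tiny-gpt2 and num_samples=10 are arbitrary placeholder choices.

# Hypothetical smoke test for the new multi-dataset evaluation helpers.
# Assumes importing app does not launch the Gradio app (demo.launch() guarded).
from app import load_model, evaluate_model_on_dataset

model_name = "sshleifer/tiny-gpt2"  # placeholder tiny causal LM, chosen only for speed
tokenizer, model, device = load_model(model_name)

# "boolq_en" is one of the keys defined in DATASETS by this commit.
result = evaluate_model_on_dataset(
    model_name, tokenizer, model, device, "boolq_en", 10
)
print(result)  # dict with model_name, dataset, num_samples, accuracy,
               # avg_time_per_sample_sec, total_time_sec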