AppleSwing committed
Commit 4045483
Parent: 28b6090

Fix bugs in gsm8k

backend-cli.py CHANGED
@@ -152,7 +152,7 @@ def process_evaluation(task: Task, eval_request: EvalRequest, limit: Optional[int]
     monitor_thread.start()
 
     original_apply = RegexFilter.apply
-    if task.benchmark == "gsm8k":
+    if task.benchmark in ["gsm8k", "gsm8k_cot", "gsm8k_cot_self_consistency", "gsm8k_custom"]:
        RegexFilter.apply = tuple_input_decorator(RegexFilter.apply)
     else:
        RegexFilter.apply = original_apply
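Note: `RegexFilter.apply` comes from lm-eval's filter pipeline, while the patched `generate_until` below returns `(text, end_to_end_time, prefilling_time, token_per_sec)` tuples rather than plain strings, which is why every gsm8k variant needs the adapter. The decorator itself is defined elsewhere in backend-cli.py; a minimal sketch of what it plausibly does (the exact signature is an assumption):

```python
def tuple_input_decorator(func):
    """Sketch: let RegexFilter.apply accept tuple-shaped model responses."""
    def wrapper(self, resps, docs):
        # Each response is assumed to be
        # (text, end_to_end_time, prefilling_time, token_per_sec);
        # keep only the text, since that is all the regex filter can split.
        stripped = [
            [r[0] if isinstance(r, tuple) else r for r in group]
            for group in resps
        ]
        return func(self, stripped, docs)
    return wrapper
```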
src/backend/envs.py CHANGED
@@ -57,7 +57,7 @@ class Tasks(Enum):
 
     # task20 = Task("race", "acc", "RACE", 0)
     task21 = Task("mmlu", "acc", "MMLU", 5)
-    task22 = Task("gsm8k", "em", "GSM8K", 5)
+    task22 = Task("gsm8k_custom", "em", "GSM8K", 5)
 
 
 EVAL_REQUESTS_PATH_BACKEND = os.path.join(CACHE_PATH, "eval-queue-bk")
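The `Task` entries are positional, so the rename only touches the first field. A sketch of the shape these call sites imply (field names are assumptions, not the repo's actual definition):

```python
from dataclasses import dataclass

@dataclass
class Task:
    benchmark: str    # lm-eval task name, e.g. "gsm8k_custom"
    metric: str       # metric key read from the results, e.g. "em"
    col_name: str     # column label shown on the leaderboard, e.g. "GSM8K"
    num_fewshot: int  # few-shot examples requested at evaluation time
```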
src/backend/hflm_with_measurement.py CHANGED
@@ -295,6 +295,8 @@ class HFLMWithMeasurement(HFLM):
         # and we don't want a warning from HF
         generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
         do_sample = generation_kwargs.get("do_sample", None)
+
+        is_gsm8k = generation_kwargs.get("is_gsm8k", False)
 
         # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
         if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
@@ -302,22 +304,40 @@
 
         if do_sample is False and generation_kwargs.get("temperature") == 0.0:
             generation_kwargs.pop("temperature")
+
+        generation_kwargs.pop("is_gsm8k")
+
+        if not is_gsm8k:
         # build stopping criteria
-        stopping_criteria = stop_sequences_criteria(
-            self.tokenizer, stop, context.shape[1], context.shape[0]
-        )
-        stop_watch = StopWatch(self.tokenizer)
-        start = time()
-        res = self.model.generate(
-            input_ids=context,
-            max_length=max_length,
-            stopping_criteria=stopping_criteria,
-            pad_token_id=self.tokenizer.pad_token_id,
-            use_cache=True,
-            streamer=stop_watch,
-            **generation_kwargs,
-        )
-        end = time()
+            stopping_criteria = stop_sequences_criteria(
+                self.tokenizer, stop, context.shape[1], context.shape[0]
+            )
+            stop_watch = StopWatch(self.tokenizer)
+            start = time()
+            res = self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                streamer=stop_watch,
+                **generation_kwargs,
+            )
+            end = time()
+        else:
+            # print("Using GSM8K")
+            stop_watch = StopWatch(self.tokenizer)
+            start = time()
+            res = self.model.generate(
+                input_ids=context,
+                max_length=max_length,
+                eos_token_id=stop,
+                pad_token_id=self.tokenizer.pad_token_id,
+                use_cache=True,
+                streamer=stop_watch,
+                **generation_kwargs,
+            )
+            end = time()
 
         batch_size = context.shape[0]
         output_length = stop_watch.decoding_iterations
@@ -408,6 +428,11 @@
                 until = [eos]
             else:
                 until.append(eos)
+
+            is_gsm8k = kwargs.get("is_gsm8k", False)
+            if is_gsm8k:
+                until = [self.tokenizer.eos_token_id, self.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
+
             if "max_gen_toks" in kwargs.keys():
                 max_gen_toks = kwargs.pop("max_gen_toks")
             else:
@@ -427,6 +452,8 @@
                 left_truncate_len=max_ctx_len,
                 truncation=self.truncation,
             )
+
+            # print("context: ", self.tok_decode(context_enc[0]))
             context_enc = context_enc.to(self.device)
             attn_masks = attn_masks.to(self.device)
 
@@ -445,16 +472,18 @@
             for cont_toks, context in zip(cont_toks_list, contexts):
                 # discard context + left-padding toks if using causal decoder-only LM
                 if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
+                    # print("After Generation: ", self.tok_decode(cont_toks))
                    cont_toks = cont_toks[context_enc.shape[1] :]
-
+
                 s = self.tok_decode(cont_toks)
 
                 # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-                for term in until:
-                    if len(term) > 0:
-                        # ignore '' separator,
-                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
-                        s = s.split(term)[0]
+                if not is_gsm8k:
+                    for term in until:
+                        if len(term) > 0:
+                            # ignore '' separator,
+                            # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
+                            s = s.split(term)[0]
 
                 res.append((s, end_to_end_time, prefilling_time, token_per_sec))
 
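The gsm8k branch above stops on token ids via `eos_token_id` instead of string-based `stopping_criteria`; `generate()` accepts a list of ids there, so both the model's EOS and `<|eot_id|>` terminate decoding. A standalone sketch of that path (the model id and prompt are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  # any tokenizer defining <|eot_id|>
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Same stop list the patched generate_until builds when is_gsm8k is set:
stop = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]

inputs = tokenizer("Question: What is 7 * 6?\nAnswer:", return_tensors="pt")
out = model.generate(
    input_ids=inputs["input_ids"],
    max_length=128,
    eos_token_id=stop,  # a list of ids is accepted here
    pad_token_id=tokenizer.pad_token_id,
    use_cache=True,
)
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```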
src/backend/tasks/gsm8k/gsm8k-custom.yaml ADDED
@@ -0,0 +1,44 @@
+group:
+  - math_word_problems
+task: gsm8k_custom
+dataset_path: gsm8k
+dataset_name: main
+output_type: generate_until
+training_split: train
+fewshot_split: train
+test_split: test
+doc_to_text: "Question: {{question}}\nAnswer:"
+doc_to_target: "{{answer}}" #" {{answer.split('### ')[-1].rstrip()}}"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: false
+    regexes_to_ignore:
+      - ","
+      - "\\$"
+      - "(?s).*#### "
+      - "\\.$"
+generation_kwargs:
+  until:
+    - "<|eot_id|>"
+  do_sample: false
+  temperature: 0.0
+  is_gsm8k: true
+repeats: 1
+num_fewshot: 5
+filter_list:
+  # - name: "strict-match"
+  #   filter:
+  #     - function: "regex"
+  #       regex_pattern: "#### (\\-?[0-9\\.\\,]+)"
+  #     - function: "take_first"
+  - name: "flexible-extract"
+    filter:
+      - function: "regex"
+        group_select: -1
+        regex_pattern: "(-?[$0-9.,]{2,})|(-?[0-9]+)"
+      - function: "take_first"
+metadata:
+  version: 3.0
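The "flexible-extract" filter keeps the last numeric match in the completion (`group_select: -1`); `exact_match`'s `regexes_to_ignore` then strips `$`, commas, and a trailing period before comparison. A quick check of the regex on a made-up completion:

```python
import re

pattern = re.compile(r"(-?[$0-9.,]{2,})|(-?[0-9]+)")
completion = "She sells 16 - 3 - 4 = 9 duck eggs a day, so she makes 9 * 2 = $18."

matches = pattern.findall(completion)   # list of (group1, group2) tuples
last = matches[-1]                      # group_select: -1 -> last match wins
answer = next(g for g in last if g)
print(answer)  # "$18." -- normalized to "18" by regexes_to_ignore
```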
src/display/utils.py CHANGED
@@ -75,7 +75,7 @@ class Tasks(Enum):
     # # XXX include me back at some point
     selfcheck = Task("selfcheckgpt", "max-selfcheckgpt", "SelfCheckGPT")
     mmlu = Task("mmlu", "acc", "MMLU") #MMLU/Acc (5-shot)
-    gsm8k = Task("gsm8k", "em", "GSM8K") #GSM8K/EM (5-shot)
+    gsm8k = Task("gsm8k_custom", "em", "GSM8K") #GSM8K/EM (8-shot)
 
 
     # These classes are for user facing column names,