lhoestq HF staff commited on
Commit
61755fe
β€’
1 Parent(s): a563465
Files changed (2) hide show
  1. generate.py +10 -5
  2. gradio_app.py +2 -2
generate.py CHANGED
@@ -38,11 +38,11 @@ sampler.set_max_repeats(empty_tokens, 1)
38
 
39
  class Sample(BaseModel):
40
  # We use get_samples_generator() to replace the placeholder with the requested fields
41
- ABCDabcd: str
42
- EFGHefgh: str
43
- IJKLijkl: str
44
- MNOPmnop: str
45
- QRSTqrst: str
46
  # PS: don't use StringConstraints with max_length here since it creates a fsm that is too big
47
 
48
 
@@ -110,6 +110,11 @@ def stream_file(filename: str, prompt: str, columns: list[str], seed: int, size:
110
  columns.append(column)
111
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")
112
 
 
 
 
 
 
113
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
114
  samples_generator = get_samples_generator(new_fields=columns)
115
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")
 
38
 
39
  class Sample(BaseModel):
40
  # We use get_samples_generator() to replace the placeholder with the requested fields
41
+ ABCDabcd12: str
42
+ EFGHefgh34: str
43
+ IJKLijkl56: str
44
+ MNOPmnop78: str
45
+ QRSTqrst90: str
46
  # PS: don't use StringConstraints with max_length here since it creates a fsm that is too big
47
 
48
 
 
110
  columns.append(column)
111
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) Generating columns... DONE (total={time.time() - _start:.02f}s)")
112
 
113
+ columns = [
114
+ tokenizer.decode(tokenizer.encode(column, add_special_tokens=False)[:len(orig_field)], skip_special_tokens=True)
115
+ for column, orig_field in zip(columns, Sample.model_fields)
116
+ ]
117
+
118
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide...")
119
  samples_generator = get_samples_generator(new_fields=columns)
120
  logger.warning(f"stream_response({filename=}, {prompt=}, {columns=}) - Generating JSON regex guide... DONE (total={time.time() - _start:.02f}s)")
gradio_app.py CHANGED
@@ -27,7 +27,7 @@ def stream_output(filename: str):
27
  state_msg = (
28
  f"βœ… Done generating {size} samples in {time.time() - start_time:.2f}s"
29
  if i + 1 == size else
30
- f"βš™οΈ Generating... [{i}/{size}]"
31
  )
32
  yield df, "```json\n" + content + "\n```", state_msg
33
 
@@ -45,7 +45,7 @@ def test(filename: str):
45
  state_msg = (
46
  f"βœ… Done generating {size} samples in {time.time() - start_time:.2f}s"
47
  if i + 1 == size else
48
- f"βš™οΈ Generating... [{i}/{size}]"
49
  )
50
  yield df, "```json\n" + content + "\n```", state_msg
51
  time.sleep(0.1)
 
27
  state_msg = (
28
  f"βœ… Done generating {size} samples in {time.time() - start_time:.2f}s"
29
  if i + 1 == size else
30
+ f"βš™οΈ Generating... [{i + 1}/{size}]"
31
  )
32
  yield df, "```json\n" + content + "\n```", state_msg
33
 
 
45
  state_msg = (
46
  f"βœ… Done generating {size} samples in {time.time() - start_time:.2f}s"
47
  if i + 1 == size else
48
+ f"βš™οΈ Generating... [{i + 1}/{size}]"
49
  )
50
  yield df, "```json\n" + content + "\n```", state_msg
51
  time.sleep(0.1)