Jae-Won Chung commited on
Commit
e3f95b1
·
1 Parent(s): e8521fb

Disable markup when printing input/outputs

Browse files
Files changed (1) hide show
  1. benchmark.py +9 -6
benchmark.py CHANGED
@@ -69,6 +69,9 @@ def main(
69
  if "chatglm" in model_path.lower():
70
  raise ValueError("ChatGLM is not supported.")
71
 
 
 
 
72
  # Print out what we're about to do.
73
  if model_path.endswith("/"):
74
  model_path = model_path[:-1]
@@ -89,7 +92,7 @@ def main(
89
  table.add_row("Max New Tokens", str(max_new_tokens))
90
  table.add_row("Output CSV", output_csv_path)
91
  table.add_row("Config JSON", config_json_path)
92
- rich.get_console().print(table)
93
 
94
  # Set the device.
95
  torch.cuda.set_device(f"cuda:{device_index}")
@@ -175,8 +178,8 @@ def main(
175
  gen_params["prompt"] = prompt
176
 
177
  # Print input prompt.
178
- rich.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Prompt[/u cyan]:")
179
- rich.get_console().print(prompt.strip() + "\n", markup=False)
180
 
181
  # Generate the ouptut from the model.
182
  output_stream = generate_stream(model, tokenizer, gen_params, device="cuda")
@@ -217,11 +220,11 @@ def main(
217
  output_json.flush()
218
 
219
  # Print the response.
220
- rich.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Response[/u cyan]:")
221
- rich.get_console().print(output_text.strip() + "\n")
222
 
223
  # Print measurement.
224
- rich.print(measurements)
225
 
226
 
227
  if __name__ == "__main__":
 
69
  if "chatglm" in model_path.lower():
70
  raise ValueError("ChatGLM is not supported.")
71
 
72
+ # Get Rich Console instance.
73
+ console = rich.get_console()
74
+
75
  # Print out what we're about to do.
76
  if model_path.endswith("/"):
77
  model_path = model_path[:-1]
 
92
  table.add_row("Max New Tokens", str(max_new_tokens))
93
  table.add_row("Output CSV", output_csv_path)
94
  table.add_row("Config JSON", config_json_path)
95
+ console.print(table)
96
 
97
  # Set the device.
98
  torch.cuda.set_device(f"cuda:{device_index}")
 
178
  gen_params["prompt"] = prompt
179
 
180
  # Print input prompt.
181
+ console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Prompt[/u cyan]:")
182
+ console.print(prompt.strip() + "\n", markup=False)
183
 
184
  # Generate the ouptut from the model.
185
  output_stream = generate_stream(model, tokenizer, gen_params, device="cuda")
 
220
  output_json.flush()
221
 
222
  # Print the response.
223
+ console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Response[/u cyan]:")
224
+ console.print(output_text.strip() + "\n", markup=False)
225
 
226
  # Print measurement.
227
+ console.print(measurements)
228
 
229
 
230
  if __name__ == "__main__":