Zhiyu Wu committed on
Commit 36058af
1 Parent(s): ef7e22b

Add batching to benchmark.py (#13)

Files changed (1)
  1. scripts/benchmark.py +234 -19
scripts/benchmark.py CHANGED
@@ -6,13 +6,15 @@ import os
 import json
 import copy
 import atexit
-from typing import Generator, Literal
+from typing import Generator, Literal, Iterable, Dict
 
+import gc
+import numpy as np
 import tyro
 import torch
 import rich
 from rich.table import Table
-from fastchat.serve.inference import generate_stream
+from fastchat.serve.inference import prepare_logits_processor
 from fastchat.model.model_adapter import load_model, get_conversation_template
 from zeus.monitor import ZeusMonitor
 
@@ -37,6 +39,194 @@ SYSTEM_PROMPTS = {
     ),
 }
 
+def is_partial_stop(output: str, stop_str: str):
+    """Check whether the output contains a partial stop str."""
+    for i in range(0, min(len(output), len(stop_str))):
+        if stop_str.startswith(output[-i:]):
+            return True
+    return False
+
+@torch.inference_mode()
+def generate_stream(
+    model,
+    tokenizer,
+    params: Dict,
+    device: str,
+    context_len: int = 2048,
+):
+    # Read parameters
+    prompts = params["prompt"]
+    temperature = float(params.get("temperature", 1.0))
+    repetition_penalty = float(params.get("repetition_penalty", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    top_k = int(params.get("top_k", -1))  # -1 means disable
+    max_new_tokens = int(params.get("max_new_tokens", 256))
+    stop_str = params.get("stop", None)
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    stop_token_ids.append(tokenizer.eos_token_id)
+    batch_size = len(prompts)
+
+    # left append prompts with eos to make all input prompts the same length
+    tokenizer.padding_side = "left"
+    tokenizer.pad_token = tokenizer.eos_token
+
+    logits_processor = prepare_logits_processor(
+        temperature, repetition_penalty, top_p, top_k
+    )
+
+    input_ids = tokenizer(prompts, padding=True).input_ids
+    output_ids = list(input_ids)
+
+    if model.config.is_encoder_decoder:
+        max_src_len = context_len
+    else:  # truncate
+        max_src_len = context_len - max_new_tokens - 8
+
+    input_ids = [input_id[-max_src_len:] for input_id in input_ids]
+    input_len = len(input_ids[0])
+
+    if model.config.is_encoder_decoder:
+        encoder_output = model.encoder(
+            input_ids=torch.as_tensor(input_ids, device=device)
+        )[0]
+        start_ids = torch.as_tensor(
+            [[model.generation_config.decoder_start_token_id] for _ in range(batch_size)],
+            dtype=torch.int64,
+            device=device,
+        )
+
+    past_key_values = out = None
+    stopped = np.array(batch_size*[False])
+    for i in range(max_new_tokens):
+        if i == 0:  # prefill
+            if model.config.is_encoder_decoder:
+                out = model.decoder(
+                    input_ids=start_ids,
+                    encoder_hidden_states=encoder_output,
+                    use_cache=True,
+                )
+                logits = model.lm_head(out[0])
+            else:
+                out = model(torch.as_tensor(input_ids, device=device), use_cache=True)
+                logits = out.logits
+            past_key_values = out.past_key_values
+        else:  # decoding
+            if model.config.is_encoder_decoder:
+                out = model.decoder(
+                    input_ids=torch.as_tensor(
+                        [[token[0]] for token in tokens], device=device
+                    ),
+                    encoder_hidden_states=encoder_output,
+                    use_cache=True,
+                    past_key_values=past_key_values,
+                )
+                logits = model.lm_head(out[0])
+            else:
+                out = model(
+                    input_ids=torch.as_tensor(
+                        [[token[0]] for token in tokens], device=device
+                    ),
+                    use_cache=True,
+                    past_key_values=past_key_values,
+                )
+                logits = out.logits
+            past_key_values = out.past_key_values
+
+        if logits_processor:
+            if repetition_penalty > 1.0:
+                tmp_output_ids = torch.as_tensor(output_ids, device=logits.device)
+            else:
+                tmp_output_ids = None
+            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])
+        else:
+            last_token_logits = logits[:, -1, :]
+
+        if device == "mps":
+            # Switch to CPU by avoiding some bugs in mps backend.
+            last_token_logits = last_token_logits.float().to("cpu")
+
+        if temperature < 1e-5 or top_p < 1e-8:  # greedy
+            _, indices = torch.topk(last_token_logits, 2)
+            tokens = [[int(token) for token in query] for query in indices.tolist()]
+        else:
+            probs = torch.softmax(last_token_logits, dim=-1)
+            indices = torch.multinomial(probs, num_samples=2)
+            tokens = [[int(token) for token in query] for query in indices.tolist()]
+
+        old_stopped = stopped
+        stopped = np.logical_or(old_stopped, np.array([True if token[0] in stop_token_ids else False for token in tokens]))
+        output_ids = [ids + [token[0]] for ids, token in zip(output_ids, tokens)]
+
+        def slice(s, pos):
+            return s[:pos]
+        vec_slice = np.vectorize(slice, otypes=[str])
+        vec_is_partial_stop = np.vectorize(is_partial_stop)
+
+        # Yield the output tokens
+        if any(stopped):
+            tmp_output_ids = [ids[input_len:] for ids in output_ids]
+            rfind_start = 0
+            output = tokenizer.batch_decode(
+                tmp_output_ids,
+                skip_special_tokens=True,
+                spaces_between_special_tokens=False,
+                clean_up_tokenization_spaces=True,
+            )
+            output = np.array(output)
+
+            partially_stopped = np.array(len(output_ids) * [False])
+            different_indices = np.empty(shape=(0,))
+            if stop_str:
+                if isinstance(stop_str, str):
+                    pos_array = np.char.rfind(output, stop_str, rfind_start)
+                    find_stop = pos_array != -1
+                    output[find_stop] = vec_slice(output[find_stop], pos_array[find_stop])
+                    stopped = find_stop
+                    partially_stopped = vec_is_partial_stop(output, stop_str)
+                elif isinstance(stop_str, Iterable):
+                    for each_stop in stop_str:
+                        pos_array = np.char.rfind(output, stop_str, rfind_start)
+                        find_stop = pos_array != -1
+                        output[find_stop] = vec_slice(output[find_stop], pos_array[find_stop])
+                        stopped = find_stop
+                        partially_stopped = partially_stopped | vec_is_partial_stop(output, each_stop)
+                else:
+                    raise ValueError("Invalid stop field type.")
+
+            # Prevent yielding partial stop sequence
+            if not any(partially_stopped):
+                # indicates which request in batch stopped
+                different_indices = np.where(stopped != old_stopped)[0]
+                stop_length = np.array([(i, len(output[i])) for i in different_indices])
+                yield {
+                    "text": output,
+                    "stop_length": stop_length,
+                }
+
+        if all(stopped):
+            break
+
+    false_indices = np.where(stopped == False)[0]
+    if any(stopped) == False:
+        tmp_output_ids = [ids[input_len:] for ids in output_ids]
+        output = tokenizer.batch_decode(
+            tmp_output_ids,
+            skip_special_tokens=True,
+            spaces_between_special_tokens=False,
+            clean_up_tokenization_spaces=True,
+        )
+    stop_length = np.array([(i, len(output[i])) for i in false_indices])
+
+    yield {
+        "text": output,
+        "stop_length": stop_length,
+    }
+
+    # Clean
+    del past_key_values, out
+    gc.collect()
+    torch.cuda.empty_cache()
+
 
 def main(
     model_path: str,
@@ -48,6 +238,7 @@ def main(
     temperature: float = 0.7,
     repitition_penalty: float = 1.0,
     max_new_tokens: int = 512,
+    batch: int = 1,
 ) -> None:
     """Run benchmarking for one model on the entire input file.
 
@@ -125,7 +316,6 @@ def main(
         "max_new_tokens": max_new_tokens,
         "stop": conv_base.stop_str,
         "stop_token_ids": conv_base.stop_token_ids,
-        "echo": False,
     }
 
     monitor = ZeusMonitor(gpu_indices=[torch.cuda.current_device()])
@@ -160,7 +350,7 @@ def main(
 
     def dataloader(input_file: str) -> Generator[tuple[bool, str], None, None]:
         """Yields a tuple of whether this is a warmup run and the input prompt."""
-        for _ in range(3):
+        for _ in range(3*batch):
             yield True, "Say something long and random. I don't care about the content."
         for item in json.load(open(input_file, "r")):
             input_prompt = item["conversations"][0]["value"]
@@ -169,46 +359,65 @@ def main(
     # Warm up the GPU with some random prompts.
     # Forward through all the prompts.
    is_first = True
-    for is_warmup, input_prompt in dataloader(input_file):
+    convs = []
+    prompts = []
+    data_iter = iter(dataloader(input_file))
+
+    end_of_file = False  # flag to track the end of the file
+    while True:
+        try:
+            is_warmup, input_prompt = next(data_iter)
+        except StopIteration:
+            end_of_file = True  # no more data
+
         # Construct the input prompt.
-        conv = copy.deepcopy(conv_base)
-        conv.append_message(conv.roles[0], input_prompt)
-        conv.append_message(conv.roles[1], "")
-        prompt = conv.get_prompt()
-        gen_params["prompt"] = prompt
+        if not end_of_file:
+            conv = copy.deepcopy(conv_base)
+            conv.append_message(conv.roles[0], input_prompt)
+            conv.append_message(conv.roles[1], "")
+            prompt = conv.get_prompt()
+            prompts.append(prompt)
+            convs.append(conv)
+            if (len(convs) < batch): continue
+        gen_params["prompt"] = prompts
 
         # Print input prompt.
-        console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Prompt[/u cyan]:")
-        console.print(prompt.strip() + "\n", markup=False)
+        for i in range(len(convs)):
+            console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Prompt[/u cyan](batch_{i}):")
+            console.print(prompts[i].strip() + "\n", markup=False)
 
         # Generate the ouptut from the model.
-        output_stream = generate_stream(model, tokenizer, gen_params, device="cuda")
+        output_stream = generate_stream(model, tokenizer, gen_params, device="cuda", context_len=2048)
         output = {}
+        batch_token_len = {}
 
         #################################################
         # Inference and measurement zone!
         #################################################
         monitor.begin_window("inference")
         for output in output_stream:
-            pass
+            stop_length = output["stop_length"]
+            for it in stop_length:
+                batch_token_len[it[0]] = it[1]
         measurements = monitor.end_window("inference")
         #################################################
 
         # Record numbers.
         output_text = output["text"]
         if not is_warmup:
-            response_length = len(tokenizer.encode(output_text))  # number of tokens
+            response_length = int(sum(batch_token_len.values()))  # number of valid tokens
             latency = measurements.time
             throughput = response_length / latency
             energy = measurements.total_energy
             output = {
                 "model": model_name_cleaned,
+                "batch": len(convs),
                 "throughput": throughput,
                 "response_length": response_length,
                 "latency": latency,
                 "energy": energy,
-                "input": prompt.strip(),
-                "output": output_text.strip(),
+                "input": [prompt.strip() for prompt in prompts],
+                "output": [output_text[i][:batch_token_len[i]].strip() for i in range(len(convs))],
             }
             output_str = json.dumps(output, indent=4)
         if not is_warmup:
@@ -220,11 +429,17 @@ def main(
             output_json.flush()
 
         # Print the response.
-        console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Response[/u cyan]:")
-        console.print(output_text.strip() + "\n", markup=False)
+        for i in range(len(convs)):
+            console.print(f"\n[u cyan]{'Warmup ' if is_warmup else ''}Response[/u cyan](batch_{i}):")
+            console.print(output_text[i][:batch_token_len[i]].strip() + "\n", markup=False)
 
         # Print measurement.
         console.print(measurements)
+        convs = []
+        prompts = []
+
+        if end_of_file:
+            break
 
 
 if __name__ == "__main__":
 
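For reference, here is a minimal consumer sketch showing how the new batched generate_stream is meant to be driven, mirroring the loop in main() above. The names model, tokenizer, and prompts are placeholders assumed to be set up elsewhere (e.g. via FastChat's load_model and the conversation template), and only the gen_params keys that generate_stream actually reads are shown.

    # Hypothetical sketch; `model`, `tokenizer`, and `prompts` (a list of fully
    # formatted prompt strings) are assumed to exist already.
    gen_params = {
        "prompt": prompts,              # batched: a list, not a single string
        "temperature": 0.7,
        "repetition_penalty": 1.0,
        "max_new_tokens": 512,
        "stop": None,
        "stop_token_ids": None,
    }

    batch_token_len = {}
    for out in generate_stream(model, tokenizer, gen_params, device="cuda", context_len=2048):
        # Each yield carries the decoded texts plus (index, length) pairs for
        # the requests that finished (or remain unfinished at the end).
        for idx, length in out["stop_length"]:
            batch_token_len[idx] = length

    # Truncate each decoded text to its recorded length, as main() does.
    texts = [out["text"][i][:batch_token_len[i]] for i in range(len(prompts))]

Note that "stop_length" reports the length of each decoded output string, which main() sums up as the response length used for throughput.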